| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331 |
- """Base classes to manage a Client's interaction with a running kernel"""
- # Copyright (c) Jupyter Development Team.
- # Distributed under the terms of the Modified BSD License.
- import asyncio
- import atexit
- import time
- import typing as t
- from queue import Empty
- from threading import Event, Thread
- import zmq.asyncio
- from jupyter_core.utils import ensure_async
- from ._version import protocol_version_info
- from .channelsabc import HBChannelABC
- from .session import Session
- # import ZMQError in top-level namespace, to avoid ugly attribute-error messages
- # during garbage collection of threads at exit
- # -----------------------------------------------------------------------------
- # Constants and exceptions
- # -----------------------------------------------------------------------------
# Major component of the messaging protocol version, taken from ._version.
# NOTE(review): presumably protocol_version_info is a (major, minor, ...) tuple — confirm in ._version.
major_protocol_version = protocol_version_info[0]
class InvalidPortNumber(Exception):  # noqa
    """Signal that a channel was configured with an unusable port number.

    Raised, for example, by :class:`HBChannel` when its ``(ip, port)``
    address uses port ``0``.
    """
class HBChannel(Thread):
    """The heartbeat channel which monitors the kernel heartbeat.

    The channel is *not* paused when created (``__init__`` sets ``_pause`` to
    False). Once started, it repeatedly sends ``b"ping"`` on a REQ socket and
    waits up to ``time_to_dead`` seconds for a reply. If no reply arrives,
    :meth:`call_handlers` is invoked (in this thread) with the elapsed time,
    and the socket is recreated. Owners of the channel may call
    :meth:`pause` / :meth:`unpause` to suspend and resume heartbeating.
    """

    session = None
    socket = None
    address = None
    # Set to True at interpreter exit by ``_notice_exit`` (class-level flag).
    _exiting = False
    # Seconds to wait for a heartbeat reply before signalling heart failure.
    time_to_dead: float = 1.0
    _running = None
    _pause = None
    _beating = None

    def __init__(
        self,
        context: zmq.Context | None = None,
        session: Session | None = None,
        address: t.Union[t.Tuple[str, int], str] = "",
    ) -> None:
        """Create the heartbeat monitor thread.

        Parameters
        ----------
        context : :class:`zmq.Context`
            The ZMQ context used to create the REQ socket.
        session : :class:`session.Session`
            The session to use.
        address : tuple or str
            Either an ``(ip, port)`` tuple that the kernel is listening on
            (turned into a ``tcp://ip:port`` url), or a zmq url string.

        Raises
        ------
        InvalidPortNumber
            If ``address`` is a tuple whose port is ``0``.
        """
        super().__init__()
        # Daemon thread: a still-running heartbeat does not block interpreter exit.
        self.daemon = True
        self.context = context
        self.session = session
        if isinstance(address, tuple):
            if address[1] == 0:
                message = "The port number for a channel cannot be 0."
                raise InvalidPortNumber(message)
            address_str = "tcp://%s:%i" % address
        else:
            address_str = address
        self.address = address_str
        # running is False until `.start()` is called
        self._running = False
        # Set by stop(); the run loop also waits on it as an interruptible sleep.
        self._exit = Event()
        # don't start paused
        self._pause = False
        self.poller = zmq.Poller()

    @staticmethod
    @atexit.register
    def _notice_exit() -> None:
        # Class definitions can be torn down during interpreter shutdown.
        # We only need to set _exiting flag if this hasn't happened.
        if HBChannel is not None:
            HBChannel._exiting = True

    def _create_socket(self) -> None:
        """(Re)create the REQ socket, connect it and register it with the poller."""
        if self.socket is not None:
            # close previous socket, before opening a new one
            self.poller.unregister(self.socket)  # type:ignore[unreachable]
            self.socket.close()
        assert self.context is not None
        self.socket = self.context.socket(zmq.REQ)
        self.socket.linger = 1000
        assert self.address is not None
        self.socket.connect(self.address)
        self.poller.register(self.socket, zmq.POLLIN)

    async def _async_run(self) -> None:
        """The thread's main activity. Call start() instead."""
        self._create_socket()
        self._running = True
        self._beating = True
        assert self.socket is not None

        while self._running:
            if self._pause:
                # just sleep, and skip the rest of the loop
                self._exit.wait(self.time_to_dead)
                continue

            # no need to catch EFSM here, because the previous event was
            # either a recv or connect, which cannot be followed by EFSM)
            await ensure_async(self.socket.send(b"ping"))
            # Monotonic clock: the elapsed time reported to call_handlers must
            # not be skewed by wall-clock adjustments.
            request_time = time.monotonic()
            # Wait until timeout (returns early if stop() sets the exit event)
            self._exit.wait(self.time_to_dead)
            # poll(0) means return immediately (see http://api.zeromq.org/2-1:zmq-poll)
            self._beating = bool(self.poller.poll(0))
            if self._beating:
                # the poll above guarantees we have something to recv
                await ensure_async(self.socket.recv())
            elif self._running:
                # nothing was received within the time limit, signal heart failure
                since_last_heartbeat = time.monotonic() - request_time
                self.call_handlers(since_last_heartbeat)
                # and close/reopen the socket, because the REQ/REP cycle has been broken
                self._create_socket()

    def run(self) -> None:
        """Run the heartbeat thread on a private asyncio event loop."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self._async_run())
        finally:
            loop.close()

    def pause(self) -> None:
        """Pause the heartbeat."""
        self._pause = True

    def unpause(self) -> None:
        """Unpause the heartbeat."""
        self._pause = False

    def is_beating(self) -> bool:
        """Is the heartbeat running and responsive (and not paused)."""
        return bool(self.is_alive() and not self._pause and self._beating)

    def stop(self) -> None:
        """Stop the channel's event loop, join its thread and close the socket.

        Safe to call on a channel whose thread was never started: there is
        then nothing to join, and only the socket (if any) is closed.
        """
        self._running = False
        self._exit.set()
        # Thread.join() raises RuntimeError for a thread that was never started.
        if self.is_alive():
            self.join()
        self.close()

    def close(self) -> None:
        """Close the heartbeat socket (best effort) and drop the reference to it."""
        if self.socket is not None:
            # Also forget it in the poller so no closed socket stays registered.
            try:
                self.poller.unregister(self.socket)
            except KeyError:
                pass
            try:
                self.socket.close(linger=0)
            except Exception:
                pass
            self.socket = None

    def call_handlers(self, since_last_heartbeat: float) -> None:
        """Handle a missed heartbeat.

        Called in the heartbeat thread when the kernel did not reply within
        ``time_to_dead`` seconds. ``since_last_heartbeat`` is the number of
        seconds elapsed since the unanswered ping was sent.

        Subclasses should override this method. Because it runs in the
        heartbeat thread, some logic must be used to ensure that
        application-level handlers are called in the application thread.
        """
        pass


HBChannelABC.register(HBChannel)
class ZMQSocketChannel:
    """A ZMQ socket wrapper"""

    def __init__(self, socket: zmq.Socket, session: Session, loop: t.Any = None) -> None:
        """Create a channel.

        Parameters
        ----------
        socket : :class:`zmq.Socket`
            The ZMQ socket to use.
        session : :class:`session.Session`
            The session to use.
        loop
            Unused here, for other implementations
        """
        super().__init__()
        self.socket: zmq.Socket | None = socket
        self.session = session

    def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:
        """Receive one multipart message and deserialize it via the session."""
        assert self.socket is not None
        msg = self.socket.recv_multipart(**kwargs)
        _ident, smsg = self.session.feed_identities(msg)
        return self.session.deserialize(smsg)

    def get_msg(self, timeout: float | None = None) -> t.Dict[str, t.Any]:
        """Gets a message if there is one that is ready.

        Parameters
        ----------
        timeout : float, optional
            Seconds to wait for a message. It is converted to milliseconds
            for ``socket.poll``. ``None`` (the default) waits until a message
            arrives; ``0`` returns immediately.

        Raises
        ------
        queue.Empty
            If no message became ready within ``timeout``.
        """
        assert self.socket is not None
        timeout_ms = None if timeout is None else int(timeout * 1000)  # seconds to ms
        ready = self.socket.poll(timeout_ms)
        if ready:
            res = self._recv()
            return res
        else:
            raise Empty

    def get_msgs(self) -> t.List[t.Dict[str, t.Any]]:
        """Get all messages that are currently ready, without blocking."""
        msgs = []
        while True:
            try:
                # timeout=0: only drain what is already queued. With the default
                # timeout (None) the final poll would wait forever for a new message.
                msgs.append(self.get_msg(timeout=0))
            except Empty:
                break
        return msgs

    def msg_ready(self) -> bool:
        """Is there a message that has been received?"""
        assert self.socket is not None
        return bool(self.socket.poll(timeout=0))

    def close(self) -> None:
        """Close the socket channel (best effort) and drop the socket reference."""
        if self.socket is not None:
            try:
                self.socket.close(linger=0)
            except Exception:
                pass
            self.socket = None

    stop = close

    def is_alive(self) -> bool:
        """Test whether the channel is alive (i.e. still holds a socket)."""
        return self.socket is not None

    def send(self, msg: t.Dict[str, t.Any]) -> None:
        """Pass a message to the ZMQ socket to send"""
        assert self.socket is not None
        self.session.send(self.socket, msg)

    def start(self) -> None:
        """Start the socket channel (a no-op for this implementation)."""
        pass
class AsyncZMQSocketChannel(ZMQSocketChannel):
    """A ZMQ socket in an async API"""

    socket: zmq.asyncio.Socket

    def __init__(self, socket: zmq.asyncio.Socket, session: Session, loop: t.Any = None) -> None:
        """Create a channel.

        Parameters
        ----------
        socket : :class:`zmq.asyncio.Socket`
            The ZMQ socket to use.
        session : :class:`session.Session`
            The session to use.
        loop
            Unused here, for other implementations

        Raises
        ------
        ValueError
            If ``socket`` is not a :class:`zmq.asyncio.Socket`.
        """
        if not isinstance(socket, zmq.asyncio.Socket):
            msg = "Socket must be asyncio"  # type:ignore[unreachable]
            raise ValueError(msg)
        super().__init__(socket, session)

    async def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:  # type:ignore[override]
        """Receive one multipart message and deserialize it via the session."""
        assert self.socket is not None
        msg = await self.socket.recv_multipart(**kwargs)
        _, smsg = self.session.feed_identities(msg)
        return self.session.deserialize(smsg)

    async def get_msg(  # type:ignore[override]
        self, timeout: float | None = None
    ) -> t.Dict[str, t.Any]:
        """Gets a message if there is one that is ready.

        Parameters
        ----------
        timeout : float, optional
            Seconds to wait for a message. It is converted to milliseconds
            for ``socket.poll``. ``None`` (the default) waits until a message
            arrives; ``0`` returns immediately.

        Raises
        ------
        queue.Empty
            If no message became ready within ``timeout``.
        """
        assert self.socket is not None
        timeout_ms = None if timeout is None else int(timeout * 1000)  # seconds to ms
        ready = await self.socket.poll(timeout_ms)
        if ready:
            res = await self._recv()
            return res
        else:
            raise Empty

    async def get_msgs(self) -> t.List[t.Dict[str, t.Any]]:  # type:ignore[override]
        """Get all messages that are currently ready, without blocking."""
        msgs = []
        while True:
            try:
                # timeout=0: only drain what is already queued. With the default
                # timeout (None) the final poll would wait forever for a new message.
                msgs.append(await self.get_msg(timeout=0))
            except Empty:
                break
        return msgs

    async def msg_ready(self) -> bool:  # type:ignore[override]
        """Is there a message that has been received?"""
        assert self.socket is not None
        return bool(await self.socket.poll(timeout=0))
|