channels.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. """Base classes to manage a Client's interaction with a running kernel"""
  2. # Copyright (c) Jupyter Development Team.
  3. # Distributed under the terms of the Modified BSD License.
  4. import asyncio
  5. import atexit
  6. import time
  7. import typing as t
  8. from queue import Empty
  9. from threading import Event, Thread
  10. import zmq.asyncio
  11. from jupyter_core.utils import ensure_async
  12. from ._version import protocol_version_info
  13. from .channelsabc import HBChannelABC
  14. from .session import Session
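# NOTE(review): this comment refers to importing ZMQError into the top-level
# namespace (to avoid ugly attribute-error messages during garbage collection
# of threads at exit), but no such import follows — confirm whether it is stale.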
# -----------------------------------------------------------------------------
# Constants and exceptions
# -----------------------------------------------------------------------------

# Major component of the messaging protocol version: the first element of
# ``protocol_version_info`` (imported from ._version).
major_protocol_version = protocol_version_info[0]
  21. class InvalidPortNumber(Exception): # noqa
  22. """An exception raised for an invalid port number."""
  23. pass
  24. class HBChannel(Thread):
  25. """The heartbeat channel which monitors the kernel heartbeat.
  26. Note that the heartbeat channel is paused by default. As long as you start
  27. this channel, the kernel manager will ensure that it is paused and un-paused
  28. as appropriate.
  29. """
  30. session = None
  31. socket = None
  32. address = None
  33. _exiting = False
  34. time_to_dead: float = 1.0
  35. _running = None
  36. _pause = None
  37. _beating = None
  38. def __init__(
  39. self,
  40. context: zmq.Context | None = None,
  41. session: Session | None = None,
  42. address: t.Union[t.Tuple[str, int], str] = "",
  43. ) -> None:
  44. """Create the heartbeat monitor thread.
  45. Parameters
  46. ----------
  47. context : :class:`zmq.Context`
  48. The ZMQ context to use.
  49. session : :class:`session.Session`
  50. The session to use.
  51. address : zmq url
  52. Standard (ip, port) tuple that the kernel is listening on.
  53. """
  54. super().__init__()
  55. self.daemon = True
  56. self.context = context
  57. self.session = session
  58. if isinstance(address, tuple):
  59. if address[1] == 0:
  60. message = "The port number for a channel cannot be 0."
  61. raise InvalidPortNumber(message)
  62. address_str = "tcp://%s:%i" % address
  63. else:
  64. address_str = address
  65. self.address = address_str
  66. # running is False until `.start()` is called
  67. self._running = False
  68. self._exit = Event()
  69. # don't start paused
  70. self._pause = False
  71. self.poller = zmq.Poller()
  72. @staticmethod
  73. @atexit.register
  74. def _notice_exit() -> None:
  75. # Class definitions can be torn down during interpreter shutdown.
  76. # We only need to set _exiting flag if this hasn't happened.
  77. if HBChannel is not None:
  78. HBChannel._exiting = True
  79. def _create_socket(self) -> None:
  80. if self.socket is not None:
  81. # close previous socket, before opening a new one
  82. self.poller.unregister(self.socket) # type:ignore[unreachable]
  83. self.socket.close()
  84. assert self.context is not None
  85. self.socket = self.context.socket(zmq.REQ)
  86. self.socket.linger = 1000
  87. assert self.address is not None
  88. self.socket.connect(self.address)
  89. self.poller.register(self.socket, zmq.POLLIN)
  90. async def _async_run(self) -> None:
  91. """The thread's main activity. Call start() instead."""
  92. self._create_socket()
  93. self._running = True
  94. self._beating = True
  95. assert self.socket is not None
  96. while self._running:
  97. if self._pause:
  98. # just sleep, and skip the rest of the loop
  99. self._exit.wait(self.time_to_dead)
  100. continue
  101. since_last_heartbeat = 0.0
  102. # no need to catch EFSM here, because the previous event was
  103. # either a recv or connect, which cannot be followed by EFSM)
  104. await ensure_async(self.socket.send(b"ping"))
  105. request_time = time.time()
  106. # Wait until timeout
  107. self._exit.wait(self.time_to_dead)
  108. # poll(0) means return immediately (see http://api.zeromq.org/2-1:zmq-poll)
  109. self._beating = bool(self.poller.poll(0))
  110. if self._beating:
  111. # the poll above guarantees we have something to recv
  112. await ensure_async(self.socket.recv())
  113. continue
  114. elif self._running:
  115. # nothing was received within the time limit, signal heart failure
  116. since_last_heartbeat = time.time() - request_time
  117. self.call_handlers(since_last_heartbeat)
  118. # and close/reopen the socket, because the REQ/REP cycle has been broken
  119. self._create_socket()
  120. continue
  121. def run(self) -> None:
  122. """Run the heartbeat thread."""
  123. loop = asyncio.new_event_loop()
  124. asyncio.set_event_loop(loop)
  125. try:
  126. loop.run_until_complete(self._async_run())
  127. finally:
  128. loop.close()
  129. def pause(self) -> None:
  130. """Pause the heartbeat."""
  131. self._pause = True
  132. def unpause(self) -> None:
  133. """Unpause the heartbeat."""
  134. self._pause = False
  135. def is_beating(self) -> bool:
  136. """Is the heartbeat running and responsive (and not paused)."""
  137. if self.is_alive() and not self._pause and self._beating: # noqa
  138. return True
  139. else:
  140. return False
  141. def stop(self) -> None:
  142. """Stop the channel's event loop and join its thread."""
  143. self._running = False
  144. self._exit.set()
  145. self.join()
  146. self.close()
  147. def close(self) -> None:
  148. """Close the heartbeat thread."""
  149. if self.socket is not None:
  150. try:
  151. self.socket.close(linger=0)
  152. except Exception:
  153. pass
  154. self.socket = None
  155. def call_handlers(self, since_last_heartbeat: float) -> None:
  156. """This method is called in the ioloop thread when a message arrives.
  157. Subclasses should override this method to handle incoming messages.
  158. It is important to remember that this method is called in the thread
  159. so that some logic must be done to ensure that the application level
  160. handlers are called in the application thread.
  161. """
  162. pass
  163. HBChannelABC.register(HBChannel)
  164. class ZMQSocketChannel:
  165. """A ZMQ socket wrapper"""
  166. def __init__(self, socket: zmq.Socket, session: Session, loop: t.Any = None) -> None:
  167. """Create a channel.
  168. Parameters
  169. ----------
  170. socket : :class:`zmq.Socket`
  171. The ZMQ socket to use.
  172. session : :class:`session.Session`
  173. The session to use.
  174. loop
  175. Unused here, for other implementations
  176. """
  177. super().__init__()
  178. self.socket: zmq.Socket | None = socket
  179. self.session = session
  180. def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:
  181. assert self.socket is not None
  182. msg = self.socket.recv_multipart(**kwargs)
  183. _ident, smsg = self.session.feed_identities(msg)
  184. return self.session.deserialize(smsg)
  185. def get_msg(self, timeout: float | None = None) -> t.Dict[str, t.Any]:
  186. """Gets a message if there is one that is ready."""
  187. assert self.socket is not None
  188. timeout_ms = None if timeout is None else int(timeout * 1000) # seconds to ms
  189. ready = self.socket.poll(timeout_ms)
  190. if ready:
  191. res = self._recv()
  192. return res
  193. else:
  194. raise Empty
  195. def get_msgs(self) -> t.List[t.Dict[str, t.Any]]:
  196. """Get all messages that are currently ready."""
  197. msgs = []
  198. while True:
  199. try:
  200. msgs.append(self.get_msg())
  201. except Empty:
  202. break
  203. return msgs
  204. def msg_ready(self) -> bool:
  205. """Is there a message that has been received?"""
  206. assert self.socket is not None
  207. return bool(self.socket.poll(timeout=0))
  208. def close(self) -> None:
  209. """Close the socket channel."""
  210. if self.socket is not None:
  211. try:
  212. self.socket.close(linger=0)
  213. except Exception:
  214. pass
  215. self.socket = None
  216. stop = close
  217. def is_alive(self) -> bool:
  218. """Test whether the channel is alive."""
  219. return self.socket is not None
  220. def send(self, msg: t.Dict[str, t.Any]) -> None:
  221. """Pass a message to the ZMQ socket to send"""
  222. assert self.socket is not None
  223. self.session.send(self.socket, msg)
  224. def start(self) -> None:
  225. """Start the socket channel."""
  226. pass
  227. class AsyncZMQSocketChannel(ZMQSocketChannel):
  228. """A ZMQ socket in an async API"""
  229. socket: zmq.asyncio.Socket
  230. def __init__(self, socket: zmq.asyncio.Socket, session: Session, loop: t.Any = None) -> None:
  231. """Create a channel.
  232. Parameters
  233. ----------
  234. socket : :class:`zmq.asyncio.Socket`
  235. The ZMQ socket to use.
  236. session : :class:`session.Session`
  237. The session to use.
  238. loop
  239. Unused here, for other implementations
  240. """
  241. if not isinstance(socket, zmq.asyncio.Socket):
  242. msg = "Socket must be asyncio" # type:ignore[unreachable]
  243. raise ValueError(msg)
  244. super().__init__(socket, session)
  245. async def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]: # type:ignore[override]
  246. assert self.socket is not None
  247. msg = await self.socket.recv_multipart(**kwargs)
  248. _, smsg = self.session.feed_identities(msg)
  249. return self.session.deserialize(smsg)
  250. async def get_msg( # type:ignore[override]
  251. self, timeout: float | None = None
  252. ) -> t.Dict[str, t.Any]:
  253. """Gets a message if there is one that is ready."""
  254. assert self.socket is not None
  255. timeout_ms = None if timeout is None else int(timeout * 1000) # seconds to ms
  256. ready = await self.socket.poll(timeout_ms)
  257. if ready:
  258. res = await self._recv()
  259. return res
  260. else:
  261. raise Empty
  262. async def get_msgs(self) -> t.List[t.Dict[str, t.Any]]: # type:ignore[override]
  263. """Get all messages that are currently ready."""
  264. msgs = []
  265. while True:
  266. try:
  267. msgs.append(await self.get_msg())
  268. except Empty:
  269. break
  270. return msgs
  271. async def msg_ready(self) -> bool: # type:ignore[override]
  272. """Is there a message that has been received?"""
  273. assert self.socket is not None
  274. return bool(await self.socket.poll(timeout=0))