threaded.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. """Defines a KernelClient that provides thread-safe sockets with async callbacks on message
  2. replies.
  3. """
  4. import asyncio
  5. import atexit
  6. import time
  7. from concurrent.futures import Future
  8. from functools import partial
  9. from threading import Thread
  10. from typing import Any
  11. import zmq
  12. from tornado.ioloop import IOLoop
  13. from traitlets import Instance, Type
  14. from traitlets.log import get_logger
  15. from zmq.eventloop import zmqstream
  16. from .channels import HBChannel
  17. from .client import KernelClient
  18. from .session import Session
  19. # Local imports
  20. # import ZMQError in top-level namespace, to avoid ugly attribute-error messages
  21. # during garbage collection of threads at exit
  22. class ThreadedZMQSocketChannel:
  23. """A ZMQ socket invoking a callback in the ioloop"""
  24. session = None
  25. socket = None
  26. ioloop = None
  27. stream = None
  28. _inspect = None
  29. def __init__(
  30. self,
  31. socket: zmq.Socket | None,
  32. session: Session | None,
  33. loop: IOLoop | None,
  34. ) -> None:
  35. """Create a channel.
  36. Parameters
  37. ----------
  38. socket : :class:`zmq.Socket`
  39. The ZMQ socket to use.
  40. session : :class:`session.Session`
  41. The session to use.
  42. loop
  43. A tornado ioloop to connect the socket to using a ZMQStream
  44. """
  45. super().__init__()
  46. self.socket = socket
  47. self.session = session
  48. self.ioloop = loop
  49. f: Future = Future()
  50. def setup_stream() -> None:
  51. try:
  52. assert self.socket is not None
  53. self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
  54. self.stream.on_recv(self._handle_recv)
  55. except Exception as e:
  56. f.set_exception(e)
  57. else:
  58. f.set_result(None)
  59. assert self.ioloop is not None
  60. self.ioloop.add_callback(setup_stream)
  61. # don't wait forever, raise any errors
  62. f.result(timeout=10)
  63. _is_alive = False
  64. def is_alive(self) -> bool:
  65. """Whether the channel is alive."""
  66. return self._is_alive
  67. def start(self) -> None:
  68. """Start the channel."""
  69. self._is_alive = True
  70. def stop(self) -> None:
  71. """Stop the channel."""
  72. self._is_alive = False
  73. def close(self) -> None:
  74. """Close the channel."""
  75. if self.stream is not None and self.ioloop is not None:
  76. # c.f.Future for threadsafe results
  77. f: Future = Future()
  78. def close_stream() -> None:
  79. try:
  80. if self.stream is not None:
  81. self.stream.close(linger=0)
  82. self.stream = None
  83. except Exception as e:
  84. f.set_exception(e)
  85. else:
  86. f.set_result(None)
  87. self.ioloop.add_callback(close_stream)
  88. # wait for result
  89. try:
  90. f.result(timeout=5)
  91. except Exception as e:
  92. log = get_logger()
  93. msg = f"Error closing stream {self.stream}: {e}"
  94. log.warning(msg, RuntimeWarning, stacklevel=2)
  95. if self.socket is not None:
  96. try:
  97. self.socket.close(linger=0)
  98. except Exception:
  99. pass
  100. self.socket = None
  101. def send(self, msg: dict[str, Any]) -> None:
  102. """Queue a message to be sent from the IOLoop's thread.
  103. Parameters
  104. ----------
  105. msg : message to send
  106. This is threadsafe, as it uses IOLoop.add_callback to give the loop's
  107. thread control of the action.
  108. """
  109. def thread_send() -> None:
  110. assert self.session is not None
  111. self.session.send(self.stream, msg)
  112. assert self.ioloop is not None
  113. self.ioloop.add_callback(thread_send)
  114. def _handle_recv(self, msg_list: list) -> None:
  115. """Callback for stream.on_recv.
  116. Unpacks message, and calls handlers with it.
  117. """
  118. assert self.ioloop is not None
  119. assert self.session is not None
  120. _ident, smsg = self.session.feed_identities(msg_list)
  121. msg = self.session.deserialize(smsg)
  122. # let client inspect messages
  123. if self._inspect:
  124. self._inspect(msg) # type:ignore[unreachable]
  125. self.call_handlers(msg)
  126. def call_handlers(self, msg: dict[str, Any]) -> None:
  127. """This method is called in the ioloop thread when a message arrives.
  128. Subclasses should override this method to handle incoming messages.
  129. It is important to remember that this method is called in the thread
  130. so that some logic must be done to ensure that the application level
  131. handlers are called in the application thread.
  132. """
  133. pass
  134. def process_events(self) -> None:
  135. """Subclasses should override this with a method
  136. processing any pending GUI events.
  137. """
  138. pass
  139. def flush(self, timeout: float = 1.0) -> None:
  140. """Immediately processes all pending messages on this channel.
  141. This is only used for the IOPub channel.
  142. Callers should use this method to ensure that :meth:`call_handlers`
  143. has been called for all messages that have been received on the
  144. 0MQ SUB socket of this channel.
  145. This method is thread safe.
  146. Parameters
  147. ----------
  148. timeout : float, optional
  149. The maximum amount of time to spend flushing, in seconds. The
  150. default is one second.
  151. """
  152. # We do the IOLoop callback process twice to ensure that the IOLoop
  153. # gets to perform at least one full poll.
  154. stop_time = time.monotonic() + timeout
  155. assert self.ioloop is not None
  156. if self.stream is None or self.stream.closed():
  157. # don't bother scheduling flush on a thread if we're closed
  158. _msg = "Attempt to flush closed stream"
  159. raise OSError(_msg)
  160. def flush(f: Any) -> None:
  161. try:
  162. self._flush()
  163. except Exception as e:
  164. f.set_exception(e)
  165. else:
  166. f.set_result(None)
  167. for _ in range(2):
  168. f: Future = Future()
  169. self.ioloop.add_callback(partial(flush, f))
  170. # wait for async flush, re-raise any errors
  171. timeout = max(stop_time - time.monotonic(), 0)
  172. try:
  173. f.result(max(stop_time - time.monotonic(), 0))
  174. except TimeoutError:
  175. # flush with a timeout means stop waiting, not raise
  176. return
  177. def _flush(self) -> None:
  178. """Callback for :method:`self.flush`."""
  179. # Race condition: flush() checks stream validity then schedules this
  180. # callback on the ioloop thread. Between scheduling and execution,
  181. # stop_channels() may close the stream (e.g., during teardown).
  182. # Handle gracefully rather than asserting, since this is an expected
  183. # edge case during shutdown, not a programming error.
  184. if self.stream is None or self.stream.closed():
  185. return
  186. self.stream.flush()
  187. self._flushed = True
  188. class IOLoopThread(Thread):
  189. """Run a pyzmq ioloop in a thread to send and receive messages"""
  190. _exiting = False
  191. ioloop = None
  192. def __init__(self) -> None:
  193. """Initialize an io loop thread."""
  194. super().__init__()
  195. self.daemon = True
  196. # Instance variable to track exit state for this specific thread.
  197. # The class variable _exiting is used by _notice_exit for interpreter shutdown.
  198. # Without this instance variable, stopping one IOLoopThread sets the class-level
  199. # _exiting = True, causing all subsequent IOLoopThread instances to exit immediately
  200. # in _async_run(). This breaks sequential kernel usage (e.g., qtconsole tests).
  201. self._exiting = False
  202. @staticmethod
  203. @atexit.register
  204. def _notice_exit() -> None:
  205. # Class definitions can be torn down during interpreter shutdown.
  206. # We only need to set _exiting flag if this hasn't happened.
  207. if IOLoopThread is not None:
  208. IOLoopThread._exiting = True
  209. def start(self) -> None:
  210. """Start the IOLoop thread
  211. Don't return until self.ioloop is defined,
  212. which is created in the thread
  213. """
  214. self._start_future: Future = Future()
  215. Thread.start(self)
  216. # wait for start, re-raise any errors
  217. self._start_future.result(timeout=10)
  218. def run(self) -> None:
  219. """Run my loop, ignoring EINTR events in the poller"""
  220. try:
  221. loop = asyncio.new_event_loop()
  222. asyncio.set_event_loop(loop)
  223. async def assign_ioloop() -> None:
  224. self.ioloop = IOLoop.current()
  225. loop.run_until_complete(assign_ioloop())
  226. except Exception as e:
  227. self._start_future.set_exception(e)
  228. else:
  229. self._start_future.set_result(None)
  230. try:
  231. loop.run_until_complete(self._async_run())
  232. finally:
  233. loop.close()
  234. async def _async_run(self) -> None:
  235. """Run forever (until self._exiting is set)"""
  236. while not self._exiting:
  237. await asyncio.sleep(1)
  238. def stop(self) -> None:
  239. """Stop the channel's event loop and join its thread.
  240. This calls :meth:`~threading.Thread.join` and returns when the thread
  241. terminates. :class:`RuntimeError` will be raised if
  242. :meth:`~threading.Thread.start` is called again.
  243. """
  244. self._exiting = True
  245. self.join()
  246. self.close()
  247. self.ioloop = None
  248. def __del__(self) -> None:
  249. self.close()
  250. def close(self) -> None:
  251. """Close the io loop thread."""
  252. if self.ioloop is not None:
  253. try:
  254. self.ioloop.close(all_fds=True)
  255. except Exception:
  256. pass
  257. class ThreadedKernelClient(KernelClient):
  258. """A KernelClient that provides thread-safe sockets with async callbacks on message replies."""
  259. @property
  260. def ioloop(self) -> IOLoop | None: # type:ignore[override]
  261. if self.ioloop_thread:
  262. return self.ioloop_thread.ioloop
  263. return None
  264. ioloop_thread = Instance(IOLoopThread, allow_none=True)
  265. def start_channels(
  266. self,
  267. shell: bool = True,
  268. iopub: bool = True,
  269. stdin: bool = True,
  270. hb: bool = True,
  271. control: bool = True,
  272. ) -> None:
  273. """Start the channels on the client."""
  274. self.ioloop_thread = IOLoopThread()
  275. self.ioloop_thread.start()
  276. if shell:
  277. self.shell_channel._inspect = self._check_kernel_info_reply
  278. super().start_channels(shell, iopub, stdin, hb, control)
  279. def _check_kernel_info_reply(self, msg: dict[str, Any]) -> None:
  280. """This is run in the ioloop thread when the kernel info reply is received"""
  281. if msg["msg_type"] == "kernel_info_reply":
  282. self._handle_kernel_info_reply(msg)
  283. self.shell_channel._inspect = None
  284. def stop_channels(self) -> None:
  285. """Stop the channels on the client."""
  286. # Close channel streams while ioloop is still running
  287. # This must happen before stopping the ioloop thread, otherwise
  288. # the ZMQ streams can't be properly unregistered from the event loop
  289. if self.ioloop_thread and self.ioloop_thread.is_alive():
  290. if self._shell_channel is not None:
  291. self._shell_channel.close()
  292. if self._iopub_channel is not None:
  293. self._iopub_channel.close()
  294. if self._stdin_channel is not None:
  295. self._stdin_channel.close()
  296. if self._control_channel is not None:
  297. self._control_channel.close()
  298. super().stop_channels()
  299. if self.ioloop_thread and self.ioloop_thread.is_alive():
  300. self.ioloop_thread.stop()
  301. iopub_channel_class = Type(ThreadedZMQSocketChannel) # type:ignore[assignment]
  302. shell_channel_class = Type(ThreadedZMQSocketChannel) # type:ignore[assignment]
  303. stdin_channel_class = Type(ThreadedZMQSocketChannel) # type:ignore[assignment]
  304. hb_channel_class = Type(HBChannel) # type:ignore[assignment]
  305. control_channel_class = Type(ThreadedZMQSocketChannel) # type:ignore[assignment]
  306. def is_alive(self) -> bool:
  307. """Is the kernel process still running?"""
  308. if self._hb_channel is not None:
  309. # We don't have access to the KernelManager,
  310. # so we use the heartbeat.
  311. return self._hb_channel.is_beating()
  312. # no heartbeat and not local, we can't tell if it's running,
  313. # so naively return True
  314. return True