subshell_manager.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. """Manager of subshells in a kernel."""
  2. from __future__ import annotations
  3. import asyncio
  4. import json
  5. import typing as t
  6. import uuid
  7. from functools import partial
  8. from threading import Lock, current_thread
  9. import zmq
  10. from tornado.ioloop import IOLoop
  11. from .socket_pair import SocketPair
  12. from .subshell import SubshellThread
  13. from .thread import SHELL_CHANNEL_THREAD_NAME
  14. from .utils import _async_in_context
  15. class SubshellManager:
  16. """A manager of subshells.
  17. Controls the lifetimes of subshell threads and their associated ZMQ sockets and
  18. streams. Runs mostly in the shell channel thread.
  19. Care needed with threadsafe access here. All write access to the cache occurs in
  20. the shell channel thread so there is only ever one write access at any one time.
  21. Reading of cache information can be performed by other threads, so all reads are
  22. protected by a lock so that they are atomic.
  23. Sending reply messages via the shell_socket is wrapped by another lock to protect
  24. against multiple subshells attempting to send at the same time.
  25. .. versionadded:: 7
  26. """
  27. def __init__(
  28. self,
  29. context: zmq.Context[t.Any],
  30. shell_channel_io_loop: IOLoop,
  31. shell_socket: zmq.Socket[t.Any],
  32. ):
  33. """Initialize the subshell manager."""
  34. self._parent_thread = current_thread()
  35. self._context: zmq.Context[t.Any] = context
  36. self._shell_channel_io_loop = shell_channel_io_loop
  37. self._shell_socket = shell_socket
  38. self._cache: dict[str, SubshellThread] = {}
  39. self._lock_cache = Lock() # Sync lock across threads when accessing cache.
  40. # Inproc socket pair for communication from control thread to shell channel thread,
  41. # such as for create_subshell_request messages. Reply messages are returned straight away.
  42. self.control_to_shell_channel = SocketPair(self._context, "control")
  43. self.control_to_shell_channel.on_recv(
  44. self._shell_channel_io_loop, self._process_control_request, copy=True
  45. )
  46. # Inproc socket pair for communication from shell channel thread to main thread,
  47. # such as for execute_request messages.
  48. self._shell_channel_to_main = SocketPair(self._context, "main")
  49. # Inproc socket pair for communication from main thread to shell channel thread.
  50. # such as for execute_reply messages.
  51. self._main_to_shell_channel = SocketPair(self._context, "main-reverse")
  52. self._main_to_shell_channel.on_recv(
  53. self._shell_channel_io_loop, self._send_on_shell_channel
  54. )
  55. def close(self) -> None:
  56. """Stop all subshells and close all resources."""
  57. assert current_thread().name == SHELL_CHANNEL_THREAD_NAME
  58. with self._lock_cache:
  59. while True:
  60. try:
  61. _, subshell_thread = self._cache.popitem()
  62. except KeyError:
  63. break
  64. self._stop_subshell(subshell_thread)
  65. self.control_to_shell_channel.close()
  66. self._main_to_shell_channel.close()
  67. self._shell_channel_to_main.close()
  68. def get_shell_channel_to_subshell_pair(self, subshell_id: str | None) -> SocketPair:
  69. """Return the inproc socket pair used to send messages from the shell channel
  70. to a particular subshell or main shell."""
  71. if subshell_id is None:
  72. return self._shell_channel_to_main
  73. with self._lock_cache:
  74. return self._cache[subshell_id].shell_channel_to_subshell
  75. def get_subshell_to_shell_channel_socket(self, subshell_id: str | None) -> zmq.Socket[t.Any]:
  76. """Return the socket used by a particular subshell or main shell to send
  77. messages to the shell channel.
  78. """
  79. if subshell_id is None:
  80. return self._main_to_shell_channel.from_socket
  81. with self._lock_cache:
  82. return self._cache[subshell_id].subshell_to_shell_channel.from_socket
  83. def get_shell_channel_to_subshell_socket(self, subshell_id: str | None) -> zmq.Socket[t.Any]:
  84. """Return the socket used by the shell channel to send messages to a particular
  85. subshell or main shell.
  86. """
  87. return self.get_shell_channel_to_subshell_pair(subshell_id).from_socket
  88. def get_subshell_aborting(self, subshell_id: str) -> bool:
  89. """Get the boolean aborting flag of the specified subshell."""
  90. with self._lock_cache:
  91. return self._cache[subshell_id].aborting
  92. def get_subshell_asyncio_lock(self, subshell_id: str) -> asyncio.Lock:
  93. """Return the asyncio lock belonging to the specified subshell."""
  94. with self._lock_cache:
  95. return self._cache[subshell_id].asyncio_lock
  96. def list_subshell(self) -> list[str]:
  97. """Return list of current subshell ids.
  98. Can be called by any subshell using %subshell magic.
  99. """
  100. with self._lock_cache:
  101. return list(self._cache)
  102. def set_on_recv_callback(self, on_recv_callback):
  103. """Set the callback used by the main shell and all subshells to receive
  104. messages sent from the shell channel thread.
  105. """
  106. assert current_thread() == self._parent_thread
  107. self._on_recv_callback = on_recv_callback
  108. self._shell_channel_to_main.on_recv(
  109. IOLoop.current(), _async_in_context(partial(on_recv_callback, None))
  110. )
  111. def set_subshell_aborting(self, subshell_id: str, aborting: bool) -> None:
  112. """Set the aborting flag of the specified subshell."""
  113. with self._lock_cache:
  114. self._cache[subshell_id].aborting = aborting
  115. def subshell_id_from_thread_id(self, thread_id: int) -> str | None:
  116. """Return subshell_id of the specified thread_id.
  117. Raises RuntimeError if thread_id is not the main shell or a subshell.
  118. Only used by %subshell magic so does not have to be fast/cached.
  119. """
  120. with self._lock_cache:
  121. if thread_id == self._parent_thread.ident:
  122. return None
  123. for id, subshell in self._cache.items():
  124. if subshell.ident == thread_id:
  125. return id
  126. msg = f"Thread id {thread_id!r} does not correspond to a subshell of this kernel"
  127. raise RuntimeError(msg)
  128. def _create_subshell(self) -> str:
  129. """Create and start a new subshell thread."""
  130. assert current_thread().name == SHELL_CHANNEL_THREAD_NAME
  131. subshell_id = str(uuid.uuid4())
  132. subshell_thread = SubshellThread(subshell_id, self._context)
  133. with self._lock_cache:
  134. assert subshell_id not in self._cache
  135. self._cache[subshell_id] = subshell_thread
  136. subshell_thread.shell_channel_to_subshell.on_recv(
  137. subshell_thread.io_loop,
  138. _async_in_context(partial(self._on_recv_callback, subshell_id)),
  139. )
  140. subshell_thread.subshell_to_shell_channel.on_recv(
  141. self._shell_channel_io_loop, self._send_on_shell_channel
  142. )
  143. subshell_thread.start()
  144. return subshell_id
  145. def _delete_subshell(self, subshell_id: str) -> None:
  146. """Delete subshell identified by subshell_id.
  147. Raises key error if subshell_id not in cache.
  148. """
  149. assert current_thread().name == SHELL_CHANNEL_THREAD_NAME
  150. with self._lock_cache:
  151. subshell_threwad = self._cache.pop(subshell_id)
  152. self._stop_subshell(subshell_threwad)
  153. def _process_control_request(
  154. self,
  155. request: list[t.Any],
  156. ) -> None:
  157. """Process a control request message received on the control inproc
  158. socket and return the reply. Runs in the shell channel thread.
  159. """
  160. assert current_thread().name == SHELL_CHANNEL_THREAD_NAME
  161. try:
  162. decoded = json.loads(request[0])
  163. type = decoded["type"]
  164. reply: dict[str, t.Any] = {"status": "ok"}
  165. if type == "create":
  166. reply["subshell_id"] = self._create_subshell()
  167. elif type == "delete":
  168. subshell_id = decoded["subshell_id"]
  169. self._delete_subshell(subshell_id)
  170. elif type == "list":
  171. reply["subshell_id"] = self.list_subshell()
  172. else:
  173. msg = f"Unrecognised message type {type!r}"
  174. raise RuntimeError(msg)
  175. except BaseException as err:
  176. reply = {
  177. "status": "error",
  178. "evalue": str(err),
  179. }
  180. # Return the reply to the control thread.
  181. self.control_to_shell_channel.to_socket.send_json(reply)
  182. def _send_on_shell_channel(self, msg) -> None:
  183. assert current_thread().name == SHELL_CHANNEL_THREAD_NAME
  184. self._shell_socket.send_multipart(msg)
  185. def _stop_subshell(self, subshell_thread: SubshellThread) -> None:
  186. """Stop a subshell thread and close all of its resources."""
  187. assert current_thread().name == SHELL_CHANNEL_THREAD_NAME
  188. if subshell_thread.is_alive():
  189. subshell_thread.stop()
  190. subshell_thread.join()