| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236 |
- """Manager of subshells in a kernel."""
- from __future__ import annotations
- import asyncio
- import json
- import typing as t
- import uuid
- from functools import partial
- from threading import Lock, current_thread
- import zmq
- from tornado.ioloop import IOLoop
- from .socket_pair import SocketPair
- from .subshell import SubshellThread
- from .thread import SHELL_CHANNEL_THREAD_NAME
- from .utils import _async_in_context
- class SubshellManager:
- """A manager of subshells.
- Controls the lifetimes of subshell threads and their associated ZMQ sockets and
- streams. Runs mostly in the shell channel thread.
- Care needed with threadsafe access here. All write access to the cache occurs in
- the shell channel thread so there is only ever one write access at any one time.
- Reading of cache information can be performed by other threads, so all reads are
- protected by a lock so that they are atomic.
- Sending reply messages via the shell_socket is wrapped by another lock to protect
- against multiple subshells attempting to send at the same time.
- .. versionadded:: 7
- """
- def __init__(
- self,
- context: zmq.Context[t.Any],
- shell_channel_io_loop: IOLoop,
- shell_socket: zmq.Socket[t.Any],
- ):
- """Initialize the subshell manager."""
- self._parent_thread = current_thread()
- self._context: zmq.Context[t.Any] = context
- self._shell_channel_io_loop = shell_channel_io_loop
- self._shell_socket = shell_socket
- self._cache: dict[str, SubshellThread] = {}
- self._lock_cache = Lock() # Sync lock across threads when accessing cache.
- # Inproc socket pair for communication from control thread to shell channel thread,
- # such as for create_subshell_request messages. Reply messages are returned straight away.
- self.control_to_shell_channel = SocketPair(self._context, "control")
- self.control_to_shell_channel.on_recv(
- self._shell_channel_io_loop, self._process_control_request, copy=True
- )
- # Inproc socket pair for communication from shell channel thread to main thread,
- # such as for execute_request messages.
- self._shell_channel_to_main = SocketPair(self._context, "main")
- # Inproc socket pair for communication from main thread to shell channel thread.
- # such as for execute_reply messages.
- self._main_to_shell_channel = SocketPair(self._context, "main-reverse")
- self._main_to_shell_channel.on_recv(
- self._shell_channel_io_loop, self._send_on_shell_channel
- )
- def close(self) -> None:
- """Stop all subshells and close all resources."""
- assert current_thread().name == SHELL_CHANNEL_THREAD_NAME
- with self._lock_cache:
- while True:
- try:
- _, subshell_thread = self._cache.popitem()
- except KeyError:
- break
- self._stop_subshell(subshell_thread)
- self.control_to_shell_channel.close()
- self._main_to_shell_channel.close()
- self._shell_channel_to_main.close()
- def get_shell_channel_to_subshell_pair(self, subshell_id: str | None) -> SocketPair:
- """Return the inproc socket pair used to send messages from the shell channel
- to a particular subshell or main shell."""
- if subshell_id is None:
- return self._shell_channel_to_main
- with self._lock_cache:
- return self._cache[subshell_id].shell_channel_to_subshell
- def get_subshell_to_shell_channel_socket(self, subshell_id: str | None) -> zmq.Socket[t.Any]:
- """Return the socket used by a particular subshell or main shell to send
- messages to the shell channel.
- """
- if subshell_id is None:
- return self._main_to_shell_channel.from_socket
- with self._lock_cache:
- return self._cache[subshell_id].subshell_to_shell_channel.from_socket
- def get_shell_channel_to_subshell_socket(self, subshell_id: str | None) -> zmq.Socket[t.Any]:
- """Return the socket used by the shell channel to send messages to a particular
- subshell or main shell.
- """
- return self.get_shell_channel_to_subshell_pair(subshell_id).from_socket
- def get_subshell_aborting(self, subshell_id: str) -> bool:
- """Get the boolean aborting flag of the specified subshell."""
- with self._lock_cache:
- return self._cache[subshell_id].aborting
- def get_subshell_asyncio_lock(self, subshell_id: str) -> asyncio.Lock:
- """Return the asyncio lock belonging to the specified subshell."""
- with self._lock_cache:
- return self._cache[subshell_id].asyncio_lock
- def list_subshell(self) -> list[str]:
- """Return list of current subshell ids.
- Can be called by any subshell using %subshell magic.
- """
- with self._lock_cache:
- return list(self._cache)
- def set_on_recv_callback(self, on_recv_callback):
- """Set the callback used by the main shell and all subshells to receive
- messages sent from the shell channel thread.
- """
- assert current_thread() == self._parent_thread
- self._on_recv_callback = on_recv_callback
- self._shell_channel_to_main.on_recv(
- IOLoop.current(), _async_in_context(partial(on_recv_callback, None))
- )
- def set_subshell_aborting(self, subshell_id: str, aborting: bool) -> None:
- """Set the aborting flag of the specified subshell."""
- with self._lock_cache:
- self._cache[subshell_id].aborting = aborting
- def subshell_id_from_thread_id(self, thread_id: int) -> str | None:
- """Return subshell_id of the specified thread_id.
- Raises RuntimeError if thread_id is not the main shell or a subshell.
- Only used by %subshell magic so does not have to be fast/cached.
- """
- with self._lock_cache:
- if thread_id == self._parent_thread.ident:
- return None
- for id, subshell in self._cache.items():
- if subshell.ident == thread_id:
- return id
- msg = f"Thread id {thread_id!r} does not correspond to a subshell of this kernel"
- raise RuntimeError(msg)
- def _create_subshell(self) -> str:
- """Create and start a new subshell thread."""
- assert current_thread().name == SHELL_CHANNEL_THREAD_NAME
- subshell_id = str(uuid.uuid4())
- subshell_thread = SubshellThread(subshell_id, self._context)
- with self._lock_cache:
- assert subshell_id not in self._cache
- self._cache[subshell_id] = subshell_thread
- subshell_thread.shell_channel_to_subshell.on_recv(
- subshell_thread.io_loop,
- _async_in_context(partial(self._on_recv_callback, subshell_id)),
- )
- subshell_thread.subshell_to_shell_channel.on_recv(
- self._shell_channel_io_loop, self._send_on_shell_channel
- )
- subshell_thread.start()
- return subshell_id
- def _delete_subshell(self, subshell_id: str) -> None:
- """Delete subshell identified by subshell_id.
- Raises key error if subshell_id not in cache.
- """
- assert current_thread().name == SHELL_CHANNEL_THREAD_NAME
- with self._lock_cache:
- subshell_threwad = self._cache.pop(subshell_id)
- self._stop_subshell(subshell_threwad)
- def _process_control_request(
- self,
- request: list[t.Any],
- ) -> None:
- """Process a control request message received on the control inproc
- socket and return the reply. Runs in the shell channel thread.
- """
- assert current_thread().name == SHELL_CHANNEL_THREAD_NAME
- try:
- decoded = json.loads(request[0])
- type = decoded["type"]
- reply: dict[str, t.Any] = {"status": "ok"}
- if type == "create":
- reply["subshell_id"] = self._create_subshell()
- elif type == "delete":
- subshell_id = decoded["subshell_id"]
- self._delete_subshell(subshell_id)
- elif type == "list":
- reply["subshell_id"] = self.list_subshell()
- else:
- msg = f"Unrecognised message type {type!r}"
- raise RuntimeError(msg)
- except BaseException as err:
- reply = {
- "status": "error",
- "evalue": str(err),
- }
- # Return the reply to the control thread.
- self.control_to_shell_channel.to_socket.send_json(reply)
- def _send_on_shell_channel(self, msg) -> None:
- assert current_thread().name == SHELL_CHANNEL_THREAD_NAME
- self._shell_socket.send_multipart(msg)
- def _stop_subshell(self, subshell_thread: SubshellThread) -> None:
- """Stop a subshell thread and close all of its resources."""
- assert current_thread().name == SHELL_CHANNEL_THREAD_NAME
- if subshell_thread.is_alive():
- subshell_thread.stop()
- subshell_thread.join()
|