| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726 |
- """Utilities for connecting to jupyter kernels
- The :class:`ConnectionFileMixin` class in this module encapsulates the logic
- related to writing and reading connection files.
- """
- # Copyright (c) Jupyter Development Team.
- # Distributed under the terms of the Modified BSD License.
- from __future__ import annotations
- import errno
- import glob
- import json
- import os
- import socket
- import stat
- import tempfile
- import warnings
- from getpass import getpass
- from typing import TYPE_CHECKING, Any, Union, cast
- import zmq
- from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write
- from traitlets import Bool, CaselessStrEnum, Instance, Integer, Type, Unicode, observe
- from traitlets.config import LoggingConfigurable, SingletonConfigurable
- from .localinterfaces import localhost
- from .utils import _filefind
- if TYPE_CHECKING:
- from jupyter_client import BlockingKernelClient
- from .session import Session
# Type alias for kernel connection info: a mapping from field name
# (e.g. "shell_port", "ip", "key") to an int, str or bytes value.
KernelConnectionInfo = dict[str, Union[int, str, bytes]]
def write_connection_file(
    fname: str | None = None,
    shell_port: int = 0,
    iopub_port: int = 0,
    stdin_port: int = 0,
    hb_port: int = 0,
    control_port: int = 0,
    ip: str = "",
    key: bytes = b"",
    transport: str = "tcp",
    signature_scheme: str = "hmac-sha256",
    kernel_name: str = "",
    **kwargs: Any,
) -> tuple[str, KernelConnectionInfo]:
    """Generates a JSON config file, including the selection of random ports.

    Parameters
    ----------
    fname : str, optional
        The path to the file to write. If empty or None, a temporary
        ``.json`` file is created and used instead.
    shell_port : int, optional
        The port to use for ROUTER (shell) channel.
    iopub_port : int, optional
        The port to use for the SUB channel.
    stdin_port : int, optional
        The port to use for the ROUTER (raw input) channel.
    control_port : int, optional
        The port to use for the ROUTER (control) channel.
    hb_port : int, optional
        The port to use for the heartbeat REP channel.
    ip : str, optional
        The ip address the kernel will bind to. Defaults to ``localhost()``.
    key : bytes, optional
        The Session key used for message authentication. It is decoded to
        ``str`` before being stored in the file.
    transport : str, optional
        The transport to use. With ``"tcp"`` free ports are picked by binding
        temporary sockets to port 0; with any other value the "ports" are the
        smallest positive integers ``N`` for which the path ``{ip}-{N}`` does
        not exist yet.
    signature_scheme : str, optional
        The scheme used for message authentication.
        This has the form 'digest-hash', where 'digest'
        is the scheme used for digests, and 'hash' is the name of the hash function
        used by the digest scheme.
        Currently, 'hmac' is the only supported digest scheme,
        and 'sha256' is the default hash function.
    kernel_name : str, optional
        The name of the kernel currently connected to.
    **kwargs
        Extra entries merged into the written configuration (they override
        the entries computed here on key collision).

    Returns
    -------
    (fname, cfg) : tuple
        The path that was written and the configuration dict stored in it.
    """
    if not ip:
        ip = localhost()
    # default to temporary connector file
    if not fname:
        fd, fname = tempfile.mkstemp(".json")
        os.close(fd)

    # Ports in the order they appear in the file; a value <= 0 means
    # "pick one for me".
    requested = (
        ("shell_port", shell_port),
        ("iopub_port", iopub_port),
        ("stdin_port", stdin_port),
        ("control_port", control_port),
        ("hb_port", hb_port),
    )
    ports_needed = sum(1 for _, port in requested if port <= 0)

    # Find open ports as necessary.
    ports: list[int] = []
    if transport == "tcp":
        sockets: list[socket.socket] = []
        try:
            # Keep every probe socket open until all of them are bound,
            # so that the selected ports are guaranteed to be distinct.
            for _ in range(ports_needed):
                sock = socket.socket()
                # register first so the socket is closed even if bind fails
                sockets.append(sock)
                # struct.pack('ii', (0,0)) is 8 null bytes
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8)
                sock.bind((ip, 0))
            ports = [sock.getsockname()[1] for sock in sockets]
        finally:
            for sock in sockets:
                sock.close()
    else:
        candidate = 1
        for _ in range(ports_needed):
            while os.path.exists(f"{ip}-{candidate}"):
                candidate += 1
            ports.append(candidate)
            candidate += 1

    # Hand out the discovered ports, in order, to the channels that need one.
    free_ports = iter(ports)
    cfg: KernelConnectionInfo = {
        name: port if port > 0 else next(free_ports) for name, port in requested
    }
    cfg["ip"] = ip
    cfg["key"] = key.decode()
    cfg["transport"] = transport
    cfg["signature_scheme"] = signature_scheme
    cfg["kernel_name"] = kernel_name
    cfg.update(kwargs)

    # Only ever write this file as user read/writeable
    # This would otherwise introduce a vulnerability as a file has secrets
    # which would let others execute arbitrary code as you
    with secure_write(fname) as f:
        f.write(json.dumps(cfg, indent=2))

    if hasattr(stat, "S_ISVTX"):
        # set the sticky bit on the parent directory of the file
        # to ensure only owner can remove it
        runtime_dir = os.path.dirname(fname)
        if runtime_dir:
            permissions = os.stat(runtime_dir).st_mode
            new_permissions = permissions | stat.S_ISVTX
            if new_permissions != permissions:
                try:
                    os.chmod(runtime_dir, new_permissions)
                except OSError:
                    # Setting the sticky bit is best-effort: any OSError
                    # (typically EPERM, when we do not own runtime_dir)
                    # is deliberately ignored.
                    pass
    return fname, cfg
def find_connection_file(
    filename: str = "kernel-*.json",
    path: str | list[str] | None = None,
    profile: str | None = None,
) -> str:
    """Locate a connection file and return its absolute path.

    The file is first looked up by its explicit name in the search path.
    When that fails, ``filename`` is treated as a fileglob (it is wrapped in
    ``*...*`` unless it already contains a ``*``) and, if several files
    match, the one with the latest access time is returned.

    Parameters
    ----------
    filename : str
        The connection file or fileglob to search for.
    path : str or list of strs[optional]
        Paths in which to search for connection files. Defaults to the
        current working directory and the Jupyter runtime directory.
    profile : str [optional]
        Ignored; a warning is emitted when it is given.

    Returns
    -------
    str : The absolute path of the connection file.

    Raises
    ------
    OSError
        If nothing matches ``filename`` in the search path.
    """
    if profile is not None:
        warnings.warn(
            "Jupyter has no profiles. profile=%s has been ignored." % profile, stacklevel=2
        )

    # Normalise the search path to a list of directories.
    if path is None:
        search_dirs = [".", jupyter_runtime_dir()]
    elif isinstance(path, str):
        search_dirs = [path]
    else:
        search_dirs = path

    # First attempt: the name exactly as given.
    try:
        return _filefind(filename, search_dirs)
    except OSError:
        pass

    # Second attempt: glob, accepting any substring match unless the
    # caller already supplied a glob.
    pattern = filename if "*" in filename else "*%s*" % filename
    candidates = [
        os.path.abspath(hit)
        for directory in search_dirs
        for hit in glob.glob(os.path.join(directory, pattern))
    ]

    if not candidates:
        msg = f"Could not find {filename!r} in {search_dirs!r}"
        raise OSError(msg)
    if len(candidates) == 1:
        return candidates[0]

    # Several matches: the most recently accessed one wins.
    candidates.sort(key=lambda f: os.stat(f).st_atime)
    return candidates[-1]
def tunnel_to_kernel(
    connection_info: str | KernelConnectionInfo,
    sshserver: str,
    sshkey: str | None = None,
) -> tuple[Any, ...]:
    """Tunnel connections to a kernel via ssh.

    Five SSH tunnels are opened from localhost on this machine to the
    ports associated with the kernel. They can be either direct
    localhost-localhost tunnels, or if an intermediate server is necessary,
    the kernel must be listening on a public IP.

    Parameters
    ----------
    connection_info : dict or str (path)
        Either a connection dict, or the path to a JSON connection file
    sshserver : str
        The ssh server to use to tunnel to the kernel. Can be a full
        `user@server:port` string. ssh config aliases are respected.
    sshkey : str [optional]
        Path to file containing ssh key to use for authentication.
        Only necessary if your ssh config does not already associate
        a keyfile with the host.

    Returns
    -------
    (shell, iopub, stdin, hb, control) : ints
        The five ports on localhost that have been forwarded to the kernel.
    """
    from .ssh import tunnel

    if isinstance(connection_info, str):
        # a path was given: read the connection dict from that file
        with open(connection_info) as fh:
            connection_info = json.load(fh)
    cf = cast(dict[str, Any], connection_info)

    local_ports = tunnel.select_random_ports(5)
    remote_ports = tuple(
        cf[f"{channel}_port"] for channel in ("shell", "iopub", "stdin", "hb", "control")
    )
    remote_ip = cf["ip"]

    # Only prompt for a password when passwordless login is not possible.
    password: bool | str = (
        False
        if tunnel.try_passwordless_ssh(sshserver, sshkey)
        else getpass("SSH Password for %s: " % sshserver)
    )

    for local_port, remote_port in zip(local_ports, remote_ports, strict=False):
        tunnel.ssh_tunnel(local_port, remote_port, sshserver, remote_ip, sshkey, password)
    return tuple(local_ports)
# -----------------------------------------------------------------------------
# Mixin for classes that work with connection files
# -----------------------------------------------------------------------------

# zmq socket type created for each channel when connecting to a kernel
channel_socket_types = {
    "hb": zmq.REQ,
    "shell": zmq.DEALER,
    "iopub": zmq.SUB,
    "stdin": zmq.DEALER,
    "control": zmq.DEALER,
}

# attribute names that hold the port of each channel, e.g. "shell_port"
port_names = [f"{channel}_port" for channel in ("shell", "stdin", "iopub", "hb", "control")]
class ConnectionFileMixin(LoggingConfigurable):
    """Mixin for configurable classes that work with connection files"""

    data_dir: str | Unicode = Unicode()

    def _data_dir_default(self) -> str:
        """Default ``data_dir``: the Jupyter data directory."""
        return jupyter_data_dir()

    # The addresses for the communication channels
    connection_file = Unicode(
        "",
        config=True,
        help="""JSON file in which to store connection info [default: kernel-<pid>.json]
This file will contain the IP, ports, and authentication key needed to connect
clients to this kernel. By default, this file will be created in the security dir
of the current profile, but can be specified by absolute path.
""",
    )
    # True once this object has written connection_file itself
    _connection_file_written = Bool(False)

    transport = CaselessStrEnum(["tcp", "ipc"], default_value="tcp", config=True)
    kernel_name: str | Unicode = Unicode()

    context = Instance(zmq.Context)

    ip = Unicode(
        config=True,
        help="""Set the kernel\'s IP address [default localhost].
If the IP address is something other than localhost, then
Consoles on other machines will be able to connect
to the Kernel, so be careful!""",
    )

    def _ip_default(self) -> str:
        """Default ``ip``: an ipc path prefix for ipc transport, else localhost."""
        if self.transport == "ipc":
            if self.connection_file:
                return os.path.splitext(self.connection_file)[0] + "-ipc"
            return "kernel-ipc"
        return localhost()

    @observe("ip")
    def _ip_changed(self, change: Any) -> None:
        """Translate the wildcard ip ``*`` into ``0.0.0.0``."""
        if change["new"] == "*":
            self.ip = "0.0.0.0"  # noqa

    # protected traits

    hb_port = Integer(0, config=True, help="set the heartbeat port [default: random]")
    shell_port = Integer(0, config=True, help="set the shell (ROUTER) port [default: random]")
    iopub_port = Integer(0, config=True, help="set the iopub (PUB) port [default: random]")
    stdin_port = Integer(0, config=True, help="set the stdin (ROUTER) port [default: random]")
    control_port = Integer(0, config=True, help="set the control (ROUTER) port [default: random]")

    # names of the ports with random assignment
    _random_port_names: list[str] | None = None

    @property
    def ports(self) -> list[int]:
        """Current port values, in the order of ``port_names``."""
        return [getattr(self, name) for name in port_names]

    # The Session to use for communication with the kernel.
    session = Instance("jupyter_client.session.Session")

    def _session_default(self) -> Session:
        """Default ``session``: a new Session whose parent is this object."""
        from .session import Session

        return Session(parent=self)

    # --------------------------------------------------------------------------
    # Connection and ipc file management
    # --------------------------------------------------------------------------

    def get_connection_info(self, session: bool = False) -> KernelConnectionInfo:
        """Return the connection info as a dict

        Parameters
        ----------
        session : bool [default: False]
            If True, a clone of our session object is included in the connection info.
            If False (default), the configuration parameters of our session object
            (signature scheme and key) are included rather than the session object itself.

        Returns
        -------
        connect_info : dict
            dictionary of connection information.
        """
        info: dict[str, Any] = {
            "transport": self.transport,
            "ip": self.ip,
            "shell_port": self.shell_port,
            "iopub_port": self.iopub_port,
            "stdin_port": self.stdin_port,
            "hb_port": self.hb_port,
            "control_port": self.control_port,
        }
        if session:
            # add *clone* of my session,
            # so that state such as digest_history is not shared.
            info["session"] = self.session.clone()
        else:
            # add session info
            info.update(
                {
                    "signature_scheme": self.session.signature_scheme,
                    "key": self.session.key,
                }
            )
        return info

    # factory for blocking clients
    blocking_class = Type(klass=object, default_value="jupyter_client.BlockingKernelClient")

    def blocking_client(self) -> BlockingKernelClient:
        """Make a blocking client connected to my kernel"""
        info = self.get_connection_info()
        bc = self.blocking_class(parent=self)  # type:ignore[operator]
        bc.load_connection_info(info)
        return bc

    def cleanup_connection_file(self) -> None:
        """Cleanup connection file *if we wrote it*

        Will not raise if the connection file was already removed somehow.
        """
        if self._connection_file_written:
            # cleanup connection files on full shutdown of kernel we started
            self._connection_file_written = False
            try:
                os.remove(self.connection_file)
            except (OSError, AttributeError):
                pass

    def cleanup_ipc_files(self) -> None:
        """Cleanup ipc files if we wrote them."""
        if self.transport != "ipc":
            return
        for port in self.ports:
            ipcfile = "%s-%i" % (self.ip, port)
            try:
                os.remove(ipcfile)
            except OSError:
                # the file may never have been created or is already gone
                pass

    def _record_random_port_names(self) -> None:
        """Records which of the ports are randomly assigned.

        Records on first invocation, if the transport is tcp.
        Does nothing on later invocations."""
        if self.transport != "tcp":
            return
        if self._random_port_names is not None:
            return

        # a port that is still <= 0 has not been set explicitly
        self._random_port_names = [name for name in port_names if getattr(self, name) <= 0]

    def cleanup_random_ports(self) -> None:
        """Forgets randomly assigned port numbers and cleans up the connection file.

        Does nothing if no port numbers have been randomly assigned.
        In particular, does nothing unless the transport is tcp.
        """
        if not self._random_port_names:
            return

        for name in self._random_port_names:
            setattr(self, name, 0)

        self.cleanup_connection_file()

    def write_connection_file(self, **kwargs: Any) -> None:
        """Write connection info to JSON dict in self.connection_file.

        Does nothing when this object already wrote the file and it still
        exists. Otherwise the file is written, and the path and ports that
        were actually used are stored back on this object.

        Parameters
        ----------
        **kwargs
            Extra entries forwarded to the module-level
            ``write_connection_file`` function.
        """
        if self._connection_file_written and os.path.exists(self.connection_file):
            return

        self.connection_file, cfg = write_connection_file(
            self.connection_file,
            transport=self.transport,
            ip=self.ip,
            key=self.session.key,
            stdin_port=self.stdin_port,
            iopub_port=self.iopub_port,
            shell_port=self.shell_port,
            hb_port=self.hb_port,
            control_port=self.control_port,
            signature_scheme=self.session.signature_scheme,
            kernel_name=self.kernel_name,
            **kwargs,
        )
        # write_connection_file also sets default ports:
        self._record_random_port_names()
        for name in port_names:
            setattr(self, name, cfg[name])

        self._connection_file_written = True

    def load_connection_file(self, connection_file: str | None = None) -> None:
        """Load connection info from JSON dict in self.connection_file.

        Parameters
        ----------
        connection_file: unicode, optional
            Path to connection file to load.
            If unspecified, use self.connection_file
        """
        if connection_file is None:
            connection_file = self.connection_file
        self.log.debug("Loading connection file %s", connection_file)
        with open(connection_file) as f:
            info = json.load(f)
        self.load_connection_info(info)

    def load_connection_info(self, info: KernelConnectionInfo) -> None:
        """Load connection info from a dict containing connection info.

        Typically this data comes from a connection file
        and is called by load_connection_file.

        Parameters
        ----------
        info: dict
            Dictionary containing connection_info.
            See the connection_file spec for details.
        """
        self.transport = info.get("transport", self.transport)
        self.ip = info.get("ip", self._ip_default())  # type:ignore[assignment]

        self._record_random_port_names()
        for name in port_names:
            if getattr(self, name) == 0 and name in info:
                # not overridden by config or cl_args
                setattr(self, name, info[name])

        if "key" in info:
            key = info["key"]
            if isinstance(key, str):
                # connection files store the key as text; the session wants bytes
                key = key.encode()
            assert isinstance(key, bytes)

            self.session.key = key
        if "signature_scheme" in info:
            self.session.signature_scheme = info["signature_scheme"]

    def _reconcile_connection_info(self, info: KernelConnectionInfo) -> None:
        """Reconciles the connection information returned from the Provisioner.

        Because some provisioners (like derivations of LocalProvisioner) may have already
        written the connection file, this method needs to ensure that, if the connection
        file exists, its contents match that of what was returned by the provisioner.  If
        the file does exist and its contents do not match, the file will be replaced with
        the provisioner information (which is considered the truth).

        If the file does not exist, the connection information in 'info' is loaded into the
        KernelManager and written to the file.

        Raises
        ------
        ValueError
            If, after reconciliation, this object's connection information
            still differs from ``info``.
        """
        # Prevent over-writing a file that has already been written with the same
        # info.  This is to prevent a race condition where the process has
        # already been launched but has not yet read the connection file - as is
        # the case with LocalProvisioners.
        file_exists: bool = False
        if os.path.exists(self.connection_file):
            with open(self.connection_file) as f:
                file_info = json.load(f)
            # Prior to the following comparison, we need to adjust the value of "key" to
            # be bytes, otherwise the comparison below will fail.  A file without a
            # (string) key is left as is, so it simply compares as a mismatch below
            # instead of raising here.
            file_key = file_info.get("key")
            if isinstance(file_key, str):
                file_info["key"] = file_key.encode()
            if not self._equal_connections(info, file_info):
                os.remove(self.connection_file)  # Contents mismatch - remove the file
                self._connection_file_written = False
            else:
                file_exists = True

        if not file_exists:
            # Load the connection info and write out file, clearing existing
            # port-based attributes so they will be reloaded
            for name in port_names:
                setattr(self, name, 0)
            self.load_connection_info(info)
            self.write_connection_file()

        # Ensure what is in KernelManager is what we expect.
        km_info = self.get_connection_info()
        if not self._equal_connections(info, km_info):
            msg = (
                "KernelManager's connection information already exists and does not match "
                "the expected values returned from provisioner!"
            )
            raise ValueError(msg)

    @staticmethod
    def _equal_connections(conn1: KernelConnectionInfo, conn2: KernelConnectionInfo) -> bool:
        """Compares pertinent keys of connection info data. Returns True if equivalent, False otherwise."""

        pertinent_keys = [
            "key",
            "ip",
            "stdin_port",
            "iopub_port",
            "shell_port",
            "control_port",
            "hb_port",
            "transport",
            "signature_scheme",
        ]

        # a key missing on both sides compares equal (None == None)
        return all(conn1.get(key) == conn2.get(key) for key in pertinent_keys)

    # --------------------------------------------------------------------------
    # Creating connected sockets
    # --------------------------------------------------------------------------

    def _make_url(self, channel: str) -> str:
        """Make a ZeroMQ URL for a given channel."""
        transport = self.transport
        ip = self.ip
        port = getattr(self, "%s_port" % channel)

        if transport == "tcp":
            return "tcp://%s:%i" % (ip, port)
        return f"{transport}://{ip}-{port}"

    def _create_connected_socket(
        self, channel: str, identity: bytes | None = None
    ) -> zmq.sugar.socket.Socket:
        """Create a zmq Socket and connect it to the kernel."""
        url = self._make_url(channel)
        socket_type = channel_socket_types[channel]
        self.log.debug("Connecting to: %s", url)
        sock = self.context.socket(socket_type)
        # set linger to 1s to prevent hangs at exit
        sock.linger = 1000
        if identity:
            sock.identity = identity
        sock.connect(url)
        return sock

    def connect_iopub(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the IOPub channel"""
        sock = self._create_connected_socket("iopub", identity=identity)
        # empty prefix: subscribe to every message
        sock.setsockopt(zmq.SUBSCRIBE, b"")
        return sock

    def connect_shell(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the Shell channel"""
        return self._create_connected_socket("shell", identity=identity)

    def connect_stdin(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the StdIn channel"""
        return self._create_connected_socket("stdin", identity=identity)

    def connect_hb(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the Heartbeat channel"""
        return self._create_connected_socket("hb", identity=identity)

    def connect_control(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the Control channel"""
        return self._create_connected_socket("control", identity=identity)
class LocalPortCache(SingletonConfigurable):
    """
    Used to keep track of local ports in order to prevent race conditions that
    can occur between port acquisition and usage by the kernel.  All locally-
    provisioned kernels should use this mechanism to limit the possibility of
    race conditions.  Note that this does not preclude other applications from
    acquiring a cached but unused port, thereby re-introducing the issue this
    class is attempting to resolve (minimize).
    See: https://github.com/jupyter/jupyter_client/issues/487
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # ports handed out by find_available_port and not yet returned
        self.currently_used_ports: set[int] = set()

    def find_available_port(self, ip: str) -> int:
        """Return a port on ``ip`` that is not already recorded as in use.

        A temporary socket is bound to port 0 to obtain a port number; the
        search repeats until the number is not in ``currently_used_ports``.
        The returned port is recorded there until passed to ``return_port``.
        """
        while True:
            # The context manager closes the probe socket even when
            # setsockopt/bind raises, so no socket is leaked on failure.
            with socket.socket() as tmp_sock:
                tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8)
                tmp_sock.bind((ip, 0))
                port = tmp_sock.getsockname()[1]

            # This is a workaround for https://github.com/jupyter/jupyter_client/issues/487
            # We prevent two kernels to have the same ports.
            if port not in self.currently_used_ports:
                self.currently_used_ports.add(port)
                return port

    def return_port(self, port: int) -> None:
        """Mark ``port`` as no longer in use; unknown ports are tolerated."""
        self.currently_used_ports.discard(port)
# Names exported by ``from <this module> import *``.
__all__ = [
    "KernelConnectionInfo",
    "LocalPortCache",
    "find_connection_file",
    "tunnel_to_kernel",
    "write_connection_file",
]
|