| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261 |
- """Kernel Provisioner Classes"""
- # Copyright (c) Jupyter Development Team.
- # Distributed under the terms of the Modified BSD License.
- import asyncio
- import os
- import pathlib
- import signal
- import sys
- from typing import TYPE_CHECKING, Any
- from ..connect import KernelConnectionInfo, LocalPortCache
- from ..launcher import launch_kernel
- from ..localinterfaces import is_local_ip, local_ips
- from .provisioner_base import KernelProvisionerBase
class LocalProvisioner(KernelProvisionerBase):
    """
    :class:`LocalProvisioner` is a concrete class of ABC :py:class:`KernelProvisionerBase`
    and is the out-of-box default implementation used when no kernel provisioner is
    specified in the kernel specification (``kernel.json``).  It provides functional
    parity to existing applications by launching the kernel locally and using
    :class:`subprocess.Popen` to manage its lifecycle.

    This class is intended to be subclassed for customizing local kernel environments
    and serve as a reference implementation for other custom provisioners.
    """

    # Handle returned by ``launch_kernel()``; ``None`` before launch and again
    # after ``wait()`` has reaped the process.
    process = None
    # Not referenced by any method of this class.
    # NOTE(review): presumably kept for backwards compatibility - confirm.
    _exit_future = None
    # Process id and process-group id recorded at launch.  ``pgid`` stays ``None``
    # when ``os.getpgid`` is unavailable or fails.
    pid = None
    pgid = None
    # Only persisted/restored via ``get_provisioner_info``/``load_provisioner_info``.
    ip = None
    # True once ``pre_launch`` has reserved ports from ``LocalPortCache``.
    ports_cached = False
    # Working directory recorded at launch; consulted by ``resolve_path``.
    cwd = None

    @property
    def has_process(self) -> bool:
        """Whether a kernel process handle is currently held."""
        return self.process is not None

    async def poll(self) -> int | None:
        """Poll the provisioner.

        Returns the result of ``process.poll()`` (``None`` is treated by
        ``wait()`` as "still alive"), or ``0`` when no process is held.
        """
        ret = 0
        if self.process:
            ret = self.process.poll()  # type:ignore[unreachable]
        return ret

    async def wait(self) -> int | None:
        """Wait for the provisioner process.

        Returns the value of ``process.wait()``, or ``0`` when no process is
        held.  On return the process handle has been released, so
        ``has_process`` is ``False``.
        """
        ret = 0
        if self.process:
            # Use busy loop at 100ms intervals, polling until the process is
            # not alive.  If we find the process is no longer alive, complete
            # its cleanup via the blocking wait().  Callers are responsible for
            # issuing calls to wait() using a timeout (see kill()).
            while await self.poll() is None:  # type:ignore[unreachable]
                await asyncio.sleep(0.1)

            # Process is no longer alive, wait and clear
            ret = self.process.wait()
            # Make sure all the fds get closed.
            for attr in ["stdout", "stderr", "stdin"]:
                fid = getattr(self.process, attr)
                if fid:
                    fid.close()
            self.process = None  # allow has_process to now return False
        return ret

    async def send_signal(self, signum: int) -> None:
        """Sends a signal to the process group of the kernel (this
        usually includes the kernel and any subprocesses spawned by
        the kernel).

        Note that since only SIGTERM is supported on Windows, we will
        check if the desired signal is for interrupt and apply the
        applicable code on Windows in that case.

        Does nothing when no process is held.  Exceptions raised by the final
        ``process.send_signal()`` call propagate to the caller.
        """
        if not self.process:
            return

        if signum == signal.SIGINT and sys.platform == "win32":  # type:ignore[unreachable]
            from ..win_interrupt import send_interrupt

            send_interrupt(self.process.win32_interrupt_event)
            return

        # Prefer process-group over process
        if self.pgid and hasattr(os, "killpg"):
            try:
                os.killpg(self.pgid, signum)
                return
            except OSError:
                pass  # We'll retry sending the signal to only the process below

        # If we're here, send the signal to the process and let caller handle exceptions
        self.process.send_signal(signum)

    async def kill(self, restart: bool = False) -> None:
        """Kill the provisioner and optionally restart.

        ``restart`` is accepted for interface compatibility and is not used here.
        """
        if self.process:
            if hasattr(signal, "SIGKILL"):  # type:ignore[unreachable]
                # If available, give preference to signalling the process-group over `kill()`.
                try:
                    await self.send_signal(signal.SIGKILL)
                    return
                except OSError:
                    pass  # fall back to killing just the process
            try:
                self.process.kill()
            except OSError as e:
                LocalProvisioner._tolerate_no_process(e)

    async def terminate(self, restart: bool = False) -> None:
        """Terminate the provisioner and optionally restart.

        ``restart`` is accepted for interface compatibility and is not used here.
        """
        if self.process:
            if hasattr(signal, "SIGTERM"):  # type:ignore[unreachable]
                # If available, give preference to signalling the process group over `terminate()`.
                try:
                    await self.send_signal(signal.SIGTERM)
                    return
                except OSError:
                    pass  # fall back to terminating just the process
            try:
                self.process.terminate()
            except OSError as e:
                LocalProvisioner._tolerate_no_process(e)

    @staticmethod
    def _tolerate_no_process(os_error: OSError) -> None:
        """Swallow ``os_error`` only if it means "the process is already gone".

        Raises:
            ValueError: if the error is anything other than the platform's
                "no such process" indication.
        """
        # In Windows, we will get an Access Denied error if the process
        # has already terminated. Ignore it.
        if sys.platform == "win32":
            if os_error.winerror != 5:
                err_message = f"Invalid Error, expecting error number to be 5, got {os_error}"
                raise ValueError(err_message) from os_error

        # On Unix, we may get an ESRCH error (or ProcessLookupError instance) if
        # the process has already terminated. Ignore it.
        else:
            from errno import ESRCH

            # Either indication is sufficient: a ProcessLookupError built without
            # an errno must be tolerated as well.
            if not (isinstance(os_error, ProcessLookupError) or os_error.errno == ESRCH):
                err_message = (
                    f"Invalid Error, expecting ProcessLookupError or ESRCH, got {os_error}"
                )
                raise ValueError(err_message) from os_error

    async def cleanup(self, restart: bool = False) -> None:
        """Clean up the resources used by the provisioner and optionally restart.

        When ports were reserved from the port cache and this is not a restart,
        the five connection ports are handed back to ``LocalPortCache``.
        """
        if self.ports_cached and not restart:
            # provisioner is about to be destroyed, return cached ports
            lpc = LocalPortCache.instance()
            ports = (
                self.connection_info["shell_port"],
                self.connection_info["iopub_port"],
                self.connection_info["stdin_port"],
                self.connection_info["hb_port"],
                self.connection_info["control_port"],
            )
            for port in ports:
                if TYPE_CHECKING:
                    assert isinstance(port, int)
                lpc.return_port(port)

    async def pre_launch(self, **kwargs: Any) -> dict[str, Any]:
        """Perform any steps in preparation for kernel process launch.

        This includes applying additional substitutions to the kernel launch command and env.
        It also includes preparation of launch parameters.

        Returns the updated kwargs.

        Raises:
            RuntimeError: if the parent uses the ``tcp`` transport with an IP
                that is not a local interface.
        """
        # build the Popen cmd
        extra_arguments = kwargs.pop("extra_arguments", [])

        # This should be considered temporary until a better division of labor can be defined.
        km = self.parent
        if km:
            if km.transport == "tcp" and not is_local_ip(km.ip):
                msg = (
                    "Can only launch a kernel on a local interface. "
                    f"This one is not: {km.ip}. "
                    "Make sure that the '*_address' attributes are "
                    "configured properly. "
                    f"Currently valid addresses are: {local_ips()}"
                )
                raise RuntimeError(msg)

            # write connection file / get default ports
            # TODO - change when handshake pattern is adopted
            if km.cache_ports and not self.ports_cached:
                lpc = LocalPortCache.instance()
                km.shell_port = lpc.find_available_port(km.ip)
                km.iopub_port = lpc.find_available_port(km.ip)
                km.stdin_port = lpc.find_available_port(km.ip)
                km.hb_port = lpc.find_available_port(km.ip)
                km.control_port = lpc.find_available_port(km.ip)
                self.ports_cached = True

            if "env" in kwargs:
                jupyter_session = kwargs["env"].get("JPY_SESSION_NAME", "")
                km.write_connection_file(jupyter_session=jupyter_session)
            else:
                km.write_connection_file()
            self.connection_info = km.get_connection_info()

            kernel_cmd = km.format_kernel_cmd(
                extra_arguments=extra_arguments
            )  # This needs to remain here for b/c
        else:
            kernel_cmd = self.kernel_spec.argv + extra_arguments

        return await super().pre_launch(cmd=kernel_cmd, **kwargs)

    async def launch_kernel(self, cmd: list[str], **kwargs: Any) -> KernelConnectionInfo:
        """Launch a kernel with a command.

        Records the process handle, its pid, its process-group id (when it can
        be obtained) and the working directory, then returns the connection info.
        """
        scrubbed_kwargs = LocalProvisioner._scrub_kwargs(kwargs)
        self.process = launch_kernel(cmd, **scrubbed_kwargs)

        pgid = None
        if hasattr(os, "getpgid"):
            try:
                pgid = os.getpgid(self.process.pid)
            except OSError:
                pass  # leave pgid unset; signals then go to the process only

        self.pid = self.process.pid
        self.pgid = pgid
        self.cwd = kwargs.get("cwd", pathlib.Path.cwd())
        return self.connection_info

    def resolve_path(self, path_str: str) -> str | None:
        """Resolve path to given file.

        ``~`` is expanded, and a relative path is resolved against the recorded
        working directory when one is set.  Returns the POSIX-style path if it
        exists, otherwise ``None``.
        """
        path = pathlib.Path(path_str).expanduser()
        if not path.is_absolute() and self.cwd:
            path = (pathlib.Path(self.cwd) / path).resolve()
        if path.exists():
            return path.as_posix()
        return None

    @staticmethod
    def _scrub_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
        """Remove any keyword arguments that Popen does not tolerate.

        Returns a new dict; ``kwargs`` itself is left unmodified.
        """
        keywords_to_scrub: list[str] = ["extra_arguments", "kernel_id"]
        return {kw: value for kw, value in kwargs.items() if kw not in keywords_to_scrub}

    async def get_provisioner_info(self) -> dict:
        """Captures the base information necessary for persistence relative to this instance."""
        provisioner_info = await super().get_provisioner_info()
        provisioner_info.update({"pid": self.pid, "pgid": self.pgid, "ip": self.ip})
        return provisioner_info

    async def load_provisioner_info(self, provisioner_info: dict) -> None:
        """Loads the base information necessary for persistence relative to this instance."""
        await super().load_provisioner_info(provisioner_info)
        self.pid = provisioner_info["pid"]
        self.pgid = provisioner_info["pgid"]
        self.ip = provisioner_info["ip"]
|