local_provisioner.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. """Kernel Provisioner Classes"""
  2. # Copyright (c) Jupyter Development Team.
  3. # Distributed under the terms of the Modified BSD License.
  4. import asyncio
  5. import os
  6. import pathlib
  7. import signal
  8. import sys
  9. from typing import TYPE_CHECKING, Any
  10. from ..connect import KernelConnectionInfo, LocalPortCache
  11. from ..launcher import launch_kernel
  12. from ..localinterfaces import is_local_ip, local_ips
  13. from .provisioner_base import KernelProvisionerBase
  14. class LocalProvisioner(KernelProvisionerBase):
  15. """
  16. :class:`LocalProvisioner` is a concrete class of ABC :py:class:`KernelProvisionerBase`
  17. and is the out-of-box default implementation used when no kernel provisioner is
  18. specified in the kernel specification (``kernel.json``). It provides functional
  19. parity to existing applications by launching the kernel locally and using
  20. :class:`subprocess.Popen` to manage its lifecycle.
  21. This class is intended to be subclassed for customizing local kernel environments
  22. and serve as a reference implementation for other custom provisioners.
  23. """
  24. process = None
  25. _exit_future = None
  26. pid = None
  27. pgid = None
  28. ip = None
  29. ports_cached = False
  30. cwd = None
  31. @property
  32. def has_process(self) -> bool:
  33. return self.process is not None
  34. async def poll(self) -> int | None:
  35. """Poll the provisioner."""
  36. ret = 0
  37. if self.process:
  38. ret = self.process.poll() # type:ignore[unreachable]
  39. return ret
  40. async def wait(self) -> int | None:
  41. """Wait for the provisioner process."""
  42. ret = 0
  43. if self.process:
  44. # Use busy loop at 100ms intervals, polling until the process is
  45. # not alive. If we find the process is no longer alive, complete
  46. # its cleanup via the blocking wait(). Callers are responsible for
  47. # issuing calls to wait() using a timeout (see kill()).
  48. while await self.poll() is None: # type:ignore[unreachable]
  49. await asyncio.sleep(0.1)
  50. # Process is no longer alive, wait and clear
  51. ret = self.process.wait()
  52. # Make sure all the fds get closed.
  53. for attr in ["stdout", "stderr", "stdin"]:
  54. fid = getattr(self.process, attr)
  55. if fid:
  56. fid.close()
  57. self.process = None # allow has_process to now return False
  58. return ret
  59. async def send_signal(self, signum: int) -> None:
  60. """Sends a signal to the process group of the kernel (this
  61. usually includes the kernel and any subprocesses spawned by
  62. the kernel).
  63. Note that since only SIGTERM is supported on Windows, we will
  64. check if the desired signal is for interrupt and apply the
  65. applicable code on Windows in that case.
  66. """
  67. if self.process:
  68. if signum == signal.SIGINT and sys.platform == "win32": # type:ignore[unreachable]
  69. from ..win_interrupt import send_interrupt
  70. send_interrupt(self.process.win32_interrupt_event)
  71. return
  72. # Prefer process-group over process
  73. if self.pgid and hasattr(os, "killpg"):
  74. try:
  75. os.killpg(self.pgid, signum)
  76. return
  77. except OSError:
  78. pass # We'll retry sending the signal to only the process below
  79. # If we're here, send the signal to the process and let caller handle exceptions
  80. self.process.send_signal(signum)
  81. return
  82. async def kill(self, restart: bool = False) -> None:
  83. """Kill the provisioner and optionally restart."""
  84. if self.process:
  85. if hasattr(signal, "SIGKILL"): # type:ignore[unreachable]
  86. # If available, give preference to signalling the process-group over `kill()`.
  87. try:
  88. await self.send_signal(signal.SIGKILL)
  89. return
  90. except OSError:
  91. pass
  92. try:
  93. self.process.kill()
  94. except OSError as e:
  95. LocalProvisioner._tolerate_no_process(e)
  96. async def terminate(self, restart: bool = False) -> None:
  97. """Terminate the provisioner and optionally restart."""
  98. if self.process:
  99. if hasattr(signal, "SIGTERM"): # type:ignore[unreachable]
  100. # If available, give preference to signalling the process group over `terminate()`.
  101. try:
  102. await self.send_signal(signal.SIGTERM)
  103. return
  104. except OSError:
  105. pass
  106. try:
  107. self.process.terminate()
  108. except OSError as e:
  109. LocalProvisioner._tolerate_no_process(e)
  110. @staticmethod
  111. def _tolerate_no_process(os_error: OSError) -> None:
  112. # In Windows, we will get an Access Denied error if the process
  113. # has already terminated. Ignore it.
  114. if sys.platform == "win32":
  115. if os_error.winerror != 5:
  116. err_message = f"Invalid Error, expecting error number to be 5, got {os_error}"
  117. raise ValueError(err_message)
  118. # On Unix, we may get an ESRCH error (or ProcessLookupError instance) if
  119. # the process has already terminated. Ignore it.
  120. else:
  121. from errno import ESRCH
  122. if not isinstance(os_error, ProcessLookupError) or os_error.errno != ESRCH:
  123. err_message = (
  124. f"Invalid Error, expecting ProcessLookupError or ESRCH, got {os_error}"
  125. )
  126. raise ValueError(err_message)
  127. async def cleanup(self, restart: bool = False) -> None:
  128. """Clean up the resources used by the provisioner and optionally restart."""
  129. if self.ports_cached and not restart:
  130. # provisioner is about to be destroyed, return cached ports
  131. lpc = LocalPortCache.instance()
  132. ports = (
  133. self.connection_info["shell_port"],
  134. self.connection_info["iopub_port"],
  135. self.connection_info["stdin_port"],
  136. self.connection_info["hb_port"],
  137. self.connection_info["control_port"],
  138. )
  139. for port in ports:
  140. if TYPE_CHECKING:
  141. assert isinstance(port, int)
  142. lpc.return_port(port)
  143. async def pre_launch(self, **kwargs: Any) -> dict[str, Any]:
  144. """Perform any steps in preparation for kernel process launch.
  145. This includes applying additional substitutions to the kernel launch command and env.
  146. It also includes preparation of launch parameters.
  147. Returns the updated kwargs.
  148. """
  149. # This should be considered temporary until a better division of labor can be defined.
  150. km = self.parent
  151. if km:
  152. if km.transport == "tcp" and not is_local_ip(km.ip):
  153. msg = (
  154. "Can only launch a kernel on a local interface. "
  155. f"This one is not: {km.ip}."
  156. "Make sure that the '*_address' attributes are "
  157. "configured properly. "
  158. f"Currently valid addresses are: {local_ips()}"
  159. )
  160. raise RuntimeError(msg)
  161. # build the Popen cmd
  162. extra_arguments = kwargs.pop("extra_arguments", [])
  163. # write connection file / get default ports
  164. # TODO - change when handshake pattern is adopted
  165. if km.cache_ports and not self.ports_cached:
  166. lpc = LocalPortCache.instance()
  167. km.shell_port = lpc.find_available_port(km.ip)
  168. km.iopub_port = lpc.find_available_port(km.ip)
  169. km.stdin_port = lpc.find_available_port(km.ip)
  170. km.hb_port = lpc.find_available_port(km.ip)
  171. km.control_port = lpc.find_available_port(km.ip)
  172. self.ports_cached = True
  173. if "env" in kwargs:
  174. jupyter_session = kwargs["env"].get("JPY_SESSION_NAME", "")
  175. km.write_connection_file(jupyter_session=jupyter_session)
  176. else:
  177. km.write_connection_file()
  178. self.connection_info = km.get_connection_info()
  179. kernel_cmd = km.format_kernel_cmd(
  180. extra_arguments=extra_arguments
  181. ) # This needs to remain here for b/c
  182. else:
  183. extra_arguments = kwargs.pop("extra_arguments", [])
  184. kernel_cmd = self.kernel_spec.argv + extra_arguments
  185. return await super().pre_launch(cmd=kernel_cmd, **kwargs)
  186. async def launch_kernel(self, cmd: list[str], **kwargs: Any) -> KernelConnectionInfo:
  187. """Launch a kernel with a command."""
  188. scrubbed_kwargs = LocalProvisioner._scrub_kwargs(kwargs)
  189. self.process = launch_kernel(cmd, **scrubbed_kwargs)
  190. pgid = None
  191. if hasattr(os, "getpgid"):
  192. try:
  193. pgid = os.getpgid(self.process.pid)
  194. except OSError:
  195. pass
  196. self.pid = self.process.pid
  197. self.pgid = pgid
  198. self.cwd = kwargs.get("cwd", pathlib.Path.cwd())
  199. return self.connection_info
  200. def resolve_path(self, path_str: str) -> str | None:
  201. """Resolve path to given file."""
  202. path = pathlib.Path(path_str).expanduser()
  203. if not path.is_absolute() and self.cwd:
  204. path = (pathlib.Path(self.cwd) / path).resolve()
  205. if path.exists():
  206. return path.as_posix()
  207. return None
  208. @staticmethod
  209. def _scrub_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
  210. """Remove any keyword arguments that Popen does not tolerate."""
  211. keywords_to_scrub: list[str] = ["extra_arguments", "kernel_id"]
  212. scrubbed_kwargs = kwargs.copy()
  213. for kw in keywords_to_scrub:
  214. scrubbed_kwargs.pop(kw, None)
  215. return scrubbed_kwargs
  216. async def get_provisioner_info(self) -> dict:
  217. """Captures the base information necessary for persistence relative to this instance."""
  218. provisioner_info = await super().get_provisioner_info()
  219. provisioner_info.update({"pid": self.pid, "pgid": self.pgid, "ip": self.ip})
  220. return provisioner_info
  221. async def load_provisioner_info(self, provisioner_info: dict) -> None:
  222. """Loads the base information necessary for persistence relative to this instance."""
  223. await super().load_provisioner_info(provisioner_info)
  224. self.pid = provisioner_info["pid"]
  225. self.pgid = provisioner_info["pgid"]
  226. self.ip = provisioner_info["ip"]