tunnel.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. """Basic ssh tunnel utilities, and convenience functions for tunneling
  2. zeromq connections.
  3. """
  4. # Copyright (C) 2010-2011 IPython Development Team
  5. # Copyright (C) 2011- PyZMQ Developers
  6. #
  7. # Redistributed from IPython under the terms of the BSD License.
  8. from __future__ import annotations
  9. import atexit
  10. import os
  11. import re
  12. import signal
  13. import socket
  14. import sys
  15. import warnings
  16. from getpass import getpass, getuser
  17. from multiprocessing import Process
  18. from types import ModuleType
  19. from typing import Any, cast
# paramiko is an optional dependency.  It is imported with DeprecationWarning
# silenced; when the import fails, ``paramiko`` is set to None and a local
# placeholder ``SSHException`` class is defined so the name always exists.
try:
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        import paramiko

        SSHException = paramiko.ssh_exception.SSHException
except ImportError:
    paramiko = None  # type:ignore[assignment]

    class SSHException(Exception):  # type:ignore[no-redef] # noqa
        pass

else:
    # Only reached when paramiko imported successfully.
    from .forward import forward_tunnel
# pexpect is an optional dependency; the name is bound to None when the
# import fails, and the functions that need it check for that.
pexpect: ModuleType | None
try:
    import pexpect
except ImportError:
    pexpect = None
  36. def select_random_ports(n: int) -> list[int]:
  37. """Select and return n random ports that are available."""
  38. ports = []
  39. sockets = []
  40. for _ in range(n):
  41. sock = socket.socket()
  42. sock.bind(("", 0))
  43. ports.append(sock.getsockname()[1])
  44. sockets.append(sock)
  45. for sock in sockets:
  46. sock.close()
  47. return ports
  48. # -----------------------------------------------------------------------------
  49. # Check for passwordless login
  50. # -----------------------------------------------------------------------------
  51. _password_pat = re.compile((rb"pass(word|phrase):"), re.IGNORECASE)
  52. def try_passwordless_ssh(server: str, keyfile: str | None, paramiko: Any = None) -> Any:
  53. """Attempt to make an ssh connection without a password.
  54. This is mainly used for requiring password input only once
  55. when many tunnels may be connected to the same server.
  56. If paramiko is None, the default for the platform is chosen.
  57. """
  58. if paramiko is None:
  59. paramiko = sys.platform == "win32"
  60. f = _try_passwordless_paramiko if paramiko else _try_passwordless_openssh
  61. return f(server, keyfile)
  62. def _try_passwordless_openssh(server: str, keyfile: str | None) -> bool:
  63. """Try passwordless login with shell ssh command."""
  64. if pexpect is None:
  65. msg = "pexpect unavailable, use paramiko"
  66. raise ImportError(msg)
  67. cmd = "ssh -f " + server
  68. if keyfile:
  69. cmd += " -i " + keyfile
  70. cmd += " exit"
  71. # pop SSH_ASKPASS from env
  72. env = os.environ.copy()
  73. env.pop("SSH_ASKPASS", None)
  74. ssh_newkey = "Are you sure you want to continue connecting"
  75. p = pexpect.spawn(cmd, env=env)
  76. while True:
  77. try:
  78. i = p.expect([ssh_newkey, _password_pat], timeout=0.1)
  79. if i == 0:
  80. msg = "The authenticity of the host can't be established."
  81. raise SSHException(msg)
  82. except pexpect.TIMEOUT:
  83. continue
  84. except pexpect.EOF:
  85. return True
  86. else:
  87. return False
  88. def _try_passwordless_paramiko(server: str, keyfile: str | None) -> bool:
  89. """Try passwordless login with paramiko."""
  90. if paramiko is None:
  91. msg = "Paramiko unavailable, " # type:ignore[unreachable]
  92. if sys.platform == "win32":
  93. msg += "Paramiko is required for ssh tunneled connections on Windows."
  94. else:
  95. msg += "use OpenSSH."
  96. raise ImportError(msg)
  97. username, server, port = _split_server(server)
  98. client = paramiko.SSHClient()
  99. client.load_system_host_keys()
  100. client.set_missing_host_key_policy(paramiko.WarningPolicy())
  101. try:
  102. client.connect(server, port, username=username, key_filename=keyfile, look_for_keys=True)
  103. except paramiko.AuthenticationException:
  104. return False
  105. else:
  106. client.close()
  107. return True
  108. def tunnel_connection(
  109. socket: socket.socket,
  110. addr: str,
  111. server: str,
  112. keyfile: str | None = None,
  113. password: str | None = None,
  114. paramiko: Any = None,
  115. timeout: int = 60,
  116. ) -> int:
  117. """Connect a socket to an address via an ssh tunnel.
  118. This is a wrapper for socket.connect(addr), when addr is not accessible
  119. from the local machine. It simply creates an ssh tunnel using the remaining args,
  120. and calls socket.connect('tcp://localhost:lport') where lport is the randomly
  121. selected local port of the tunnel.
  122. """
  123. new_url, tunnel = open_tunnel(
  124. addr,
  125. server,
  126. keyfile=keyfile,
  127. password=password,
  128. paramiko=paramiko,
  129. timeout=timeout,
  130. )
  131. socket.connect(new_url)
  132. return tunnel
  133. def open_tunnel(
  134. addr: str,
  135. server: str,
  136. keyfile: str | None = None,
  137. password: str | None = None,
  138. paramiko: Any = None,
  139. timeout: int = 60,
  140. ) -> tuple[str, int]:
  141. """Open a tunneled connection from a 0MQ url.
  142. For use inside tunnel_connection.
  143. Returns
  144. -------
  145. (url, tunnel) : (str, object)
  146. The 0MQ url that has been forwarded, and the tunnel object
  147. """
  148. lport = select_random_ports(1)[0]
  149. _, addr = addr.split("://")
  150. ip, rport = addr.split(":")
  151. rport_int = int(rport)
  152. paramiko = sys.platform == "win32" if paramiko is None else paramiko_tunnel
  153. tunnelf = paramiko_tunnel if paramiko else openssh_tunnel
  154. tunnel = tunnelf(
  155. lport,
  156. rport_int,
  157. server,
  158. remoteip=ip,
  159. keyfile=keyfile,
  160. password=password,
  161. timeout=timeout,
  162. )
  163. return "tcp://127.0.0.1:%i" % lport, cast(int, tunnel)
  164. def openssh_tunnel(
  165. lport: int,
  166. rport: int,
  167. server: str,
  168. remoteip: str = "127.0.0.1",
  169. keyfile: str | None = None,
  170. password: str | None | bool = None,
  171. timeout: int = 60,
  172. ) -> int:
  173. """Create an ssh tunnel using command-line ssh that connects port lport
  174. on this machine to localhost:rport on server. The tunnel
  175. will automatically close when not in use, remaining open
  176. for a minimum of timeout seconds for an initial connection.
  177. This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`,
  178. as seen from `server`.
  179. keyfile and password may be specified, but ssh config is checked for defaults.
  180. Parameters
  181. ----------
  182. lport : int
  183. local port for connecting to the tunnel from this machine.
  184. rport : int
  185. port on the remote machine to connect to.
  186. server : str
  187. The ssh server to connect to. The full ssh server string will be parsed.
  188. user@server:port
  189. remoteip : str [Default: 127.0.0.1]
  190. The remote ip, specifying the destination of the tunnel.
  191. Default is localhost, which means that the tunnel would redirect
  192. localhost:lport on this machine to localhost:rport on the *server*.
  193. keyfile : str; path to public key file
  194. This specifies a key to be used in ssh login, default None.
  195. Regular default ssh keys will be used without specifying this argument.
  196. password : str;
  197. Your ssh password to the ssh server. Note that if this is left None,
  198. you will be prompted for it if passwordless key based login is unavailable.
  199. timeout : int [default: 60]
  200. The time (in seconds) after which no activity will result in the tunnel
  201. closing. This prevents orphaned tunnels from running forever.
  202. """
  203. if pexpect is None:
  204. msg = "pexpect unavailable, use paramiko_tunnel"
  205. raise ImportError(msg)
  206. ssh = "ssh "
  207. if keyfile:
  208. ssh += "-i " + keyfile
  209. if ":" in server:
  210. server, port = server.split(":")
  211. ssh += " -p %s" % port
  212. cmd = f"{ssh} -O check {server}"
  213. (output, exitstatus) = pexpect.run(cmd, withexitstatus=True)
  214. if not exitstatus:
  215. pid = int(output[output.find(b"(pid=") + 5 : output.find(b")")])
  216. cmd = "%s -O forward -L 127.0.0.1:%i:%s:%i %s" % (
  217. ssh,
  218. lport,
  219. remoteip,
  220. rport,
  221. server,
  222. )
  223. (output, exitstatus) = pexpect.run(cmd, withexitstatus=True)
  224. if not exitstatus:
  225. atexit.register(_stop_tunnel, cmd.replace("-O forward", "-O cancel", 1))
  226. return pid
  227. cmd = "%s -f -S none -L 127.0.0.1:%i:%s:%i %s sleep %i" % (
  228. ssh,
  229. lport,
  230. remoteip,
  231. rport,
  232. server,
  233. timeout,
  234. )
  235. # pop SSH_ASKPASS from env
  236. env = os.environ.copy()
  237. env.pop("SSH_ASKPASS", None)
  238. ssh_newkey = "Are you sure you want to continue connecting"
  239. tunnel = pexpect.spawn(cmd, env=env)
  240. failed = False
  241. while True:
  242. try:
  243. i = tunnel.expect([ssh_newkey, _password_pat], timeout=0.1)
  244. if i == 0:
  245. msg = "The authenticity of the host can't be established."
  246. raise SSHException(msg)
  247. except pexpect.TIMEOUT:
  248. continue
  249. except pexpect.EOF as e:
  250. tunnel.wait()
  251. if tunnel.exitstatus:
  252. raise RuntimeError("tunnel '%s' failed to start" % (cmd)) from e
  253. else:
  254. return tunnel.pid
  255. else:
  256. if failed:
  257. warnings.warn("Password rejected, try again", stacklevel=2)
  258. password = None
  259. if password is None:
  260. password = getpass("%s's password: " % (server))
  261. tunnel.sendline(password)
  262. failed = True
def _stop_tunnel(cmd: Any) -> None:
    """Run ``cmd`` with ``pexpect.run``.

    openssh_tunnel registers this with ``atexit``, passing the
    ``ssh ... -O cancel ...`` command that removes a forward it added.
    """
    assert pexpect is not None
    pexpect.run(cmd)
  266. def _split_server(server: str) -> tuple[str, str, int]:
  267. if "@" in server:
  268. username, server = server.split("@", 1)
  269. else:
  270. username = getuser()
  271. if ":" in server:
  272. server, port_str = server.split(":")
  273. port = int(port_str)
  274. else:
  275. port = 22
  276. return username, server, port
  277. def paramiko_tunnel(
  278. lport: int,
  279. rport: int,
  280. server: str,
  281. remoteip: str = "127.0.0.1",
  282. keyfile: str | None = None,
  283. password: str | None = None,
  284. timeout: float = 60,
  285. ) -> Process:
  286. """launch a tunner with paramiko in a subprocess. This should only be used
  287. when shell ssh is unavailable (e.g. Windows).
  288. This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`,
  289. as seen from `server`.
  290. If you are familiar with ssh tunnels, this creates the tunnel:
  291. ssh server -L localhost:lport:remoteip:rport
  292. keyfile and password may be specified, but ssh config is checked for defaults.
  293. Parameters
  294. ----------
  295. lport : int
  296. local port for connecting to the tunnel from this machine.
  297. rport : int
  298. port on the remote machine to connect to.
  299. server : str
  300. The ssh server to connect to. The full ssh server string will be parsed.
  301. user@server:port
  302. remoteip : str [Default: 127.0.0.1]
  303. The remote ip, specifying the destination of the tunnel.
  304. Default is localhost, which means that the tunnel would redirect
  305. localhost:lport on this machine to localhost:rport on the *server*.
  306. keyfile : str; path to public key file
  307. This specifies a key to be used in ssh login, default None.
  308. Regular default ssh keys will be used without specifying this argument.
  309. password : str;
  310. Your ssh password to the ssh server. Note that if this is left None,
  311. you will be prompted for it if passwordless key based login is unavailable.
  312. timeout : int [default: 60]
  313. The time (in seconds) after which no activity will result in the tunnel
  314. closing. This prevents orphaned tunnels from running forever.
  315. """
  316. if paramiko is None:
  317. msg = "Paramiko not available" # type:ignore[unreachable]
  318. raise ImportError(msg)
  319. if password is None and not _try_passwordless_paramiko(server, keyfile):
  320. password = getpass("%s's password: " % (server))
  321. p = Process(
  322. target=_paramiko_tunnel,
  323. args=(lport, rport, server, remoteip),
  324. kwargs={"keyfile": keyfile, "password": password},
  325. )
  326. p.daemon = True
  327. p.start()
  328. return p
  329. def _paramiko_tunnel(
  330. lport: int,
  331. rport: int,
  332. server: str,
  333. remoteip: str,
  334. keyfile: str | None = None,
  335. password: str | None = None,
  336. ) -> None:
  337. """Function for actually starting a paramiko tunnel, to be passed
  338. to multiprocessing.Process(target=this), and not called directly.
  339. """
  340. username, server, port = _split_server(server)
  341. client = paramiko.SSHClient()
  342. client.load_system_host_keys()
  343. client.set_missing_host_key_policy(paramiko.WarningPolicy())
  344. try:
  345. client.connect(
  346. server,
  347. port,
  348. username=username,
  349. key_filename=keyfile,
  350. look_for_keys=True,
  351. password=password,
  352. )
  353. # except paramiko.AuthenticationException:
  354. # if password is None:
  355. # password = getpass("%s@%s's password: "%(username, server))
  356. # client.connect(server, port, username=username, password=password)
  357. # else:
  358. # raise
  359. except Exception as e:
  360. warnings.warn("*** Failed to connect to %s:%d: %r" % (server, port, e), stacklevel=2)
  361. sys.exit(1)
  362. # Don't let SIGINT kill the tunnel subprocess
  363. signal.signal(signal.SIGINT, signal.SIG_IGN)
  364. try:
  365. forward_tunnel(lport, remoteip, rport, client.get_transport())
  366. except KeyboardInterrupt:
  367. warnings.warn("SIGINT: Port forwarding stopped cleanly", stacklevel=2)
  368. sys.exit(0)
  369. except Exception as e:
  370. warnings.warn("Port forwarding stopped uncleanly: %s" % e, stacklevel=2)
  371. sys.exit(255)
# ``ssh_tunnel`` is the platform default tunnel function: paramiko on win32,
# command-line ssh everywhere else.
if sys.platform == "win32":
    ssh_tunnel = paramiko_tunnel
else:
    ssh_tunnel = openssh_tunnel

# Names exported by ``from <module> import *``.
__all__ = [
    "openssh_tunnel",
    "paramiko_tunnel",
    "ssh_tunnel",
    "try_passwordless_ssh",
    "tunnel_connection",
]