server.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. """
  2. Telnet server.
  3. """
  4. from __future__ import annotations
  5. import asyncio
  6. import contextvars
  7. import socket
  8. from asyncio import get_running_loop
  9. from typing import Any, Callable, Coroutine, TextIO, cast
  10. from prompt_toolkit.application.current import create_app_session, get_app
  11. from prompt_toolkit.application.run_in_terminal import run_in_terminal
  12. from prompt_toolkit.data_structures import Size
  13. from prompt_toolkit.formatted_text import AnyFormattedText, to_formatted_text
  14. from prompt_toolkit.input import PipeInput, create_pipe_input
  15. from prompt_toolkit.output.vt100 import Vt100_Output
  16. from prompt_toolkit.renderer import print_formatted_text as print_formatted_text
  17. from prompt_toolkit.styles import BaseStyle, DummyStyle
  18. from .log import logger
  19. from .protocol import (
  20. DO,
  21. ECHO,
  22. IAC,
  23. LINEMODE,
  24. MODE,
  25. NAWS,
  26. SB,
  27. SE,
  28. SEND,
  29. SUPPRESS_GO_AHEAD,
  30. TTYPE,
  31. WILL,
  32. TelnetProtocolParser,
  33. )
  # Public API of this module: only the server class is exported.
  __all__ = ["TelnetServer"]
  def int2byte(number: int) -> bytes:
      """
      Return a length-one ``bytes`` object whose only byte is ``number``.

      Raises ``ValueError`` when ``number`` is outside ``range(256)``.
      """
      return bytes([number])
  def _initialize_telnet(connection: socket.socket) -> None:
      """
      Send the server side of the initial telnet option negotiation.

      The server asks the client to use LINEMODE, offers to suppress
      go-ahead and to echo, asks for window-size (NAWS) and terminal-type
      (TTYPE) negotiation, and finally requests the client's terminal type.

      :param connection: Socket of the newly accepted client.
      :raises OSError: If the bytes cannot be written to the socket.
      """
      logger.info("Initializing telnet connection")
      negotiation = [
          # Iac Do Linemode
          IAC + DO + LINEMODE,
          # Suppress Go Ahead. (This seems important for Putty to do correct
          # echoing.) This will allow bi-directional operation.
          IAC + WILL + SUPPRESS_GO_AHEAD,
          # Iac sb: LINEMODE MODE sub-negotiation with a mode byte of 0.
          IAC + SB + LINEMODE + MODE + int2byte(0) + IAC + SE,
          # IAC Will Echo
          IAC + WILL + ECHO,
          # Negotiate window size
          IAC + DO + NAWS,
          # Negotiate terminal type.
          # Assume the client will accept the negotiation with `IAC + WILL + TTYPE`
          IAC + DO + TTYPE,
          # We can then select the first terminal type supported by the client,
          # which is generally the best type the client supports.
          # The client should reply with a `IAC + SB + TTYPE + IS + ttype + IAC + SE`
          IAC + SB + TTYPE + SEND + IAC + SE,
      ]
      # `send()` may write only part of its buffer and just report how much it
      # wrote; `sendall()` keeps writing until every byte is out (or raises).
      # TCP is a byte stream, so one write of the joined sequences puts the
      # same bytes on the wire as one write per sequence.
      connection.sendall(b"".join(negotiation))
  class _ConnectionStdout:
      """
      Wrapper around socket which provides `write` and `flush` methods for the
      Vt100_Output output.

      Newlines are translated to CR LF, text is encoded with the configured
      encoding (strict error handling) and sent right away. After `close()`
      further output is discarded.
      """

      def __init__(self, connection: socket.socket, encoding: str) -> None:
          self._encoding = encoding
          self._connection = connection
          self._errors = "strict"
          # Encoded chunks waiting to be sent by `flush()`.
          self._buffer: list[bytes] = []
          self._closed = False

      def write(self, data: str) -> None:
          """Encode ``data`` (newlines become CR LF), queue it and flush."""
          data = data.replace("\n", "\r\n")
          self._buffer.append(data.encode(self._encoding, errors=self._errors))
          self.flush()

      def isatty(self) -> bool:
          """Report that this stream is a terminal (always True)."""
          return True

      def flush(self) -> None:
          """
          Send all queued data over the socket.

          Socket errors are logged and the queued data is dropped. Nothing is
          sent once the stream has been closed.
          """
          try:
              if not self._closed and self._buffer:
                  # `send()` may transmit only a prefix of the data and return
                  # the count; `sendall()` keeps going until all of it is out.
                  self._connection.sendall(b"".join(self._buffer))
          except OSError as e:
              logger.warning("Couldn't send data over socket: %s", e)
          self._buffer = []

      def close(self) -> None:
          """Stop sending; later writes are discarded."""
          self._closed = True

      @property
      def encoding(self) -> str:
          """Encoding used to turn written text into bytes."""
          return self._encoding

      @property
      def errors(self) -> str:
          """Encoding error handler (always ``"strict"``)."""
          return self._errors
  class TelnetConnection:
      """
      Class that represents one Telnet connection.

      The constructor sends the initial telnet negotiation. Incoming bytes are
      fed through a `TelnetProtocolParser`; once the client reports its
      terminal type a `Vt100_Output` is created, and `run_application` then
      awaits the `interact` coroutine inside an app session that uses this
      connection's input and output.
      """

      def __init__(
          self,
          conn: socket.socket,
          addr: tuple[str, int],
          interact: Callable[[TelnetConnection], Coroutine[Any, Any, None]],
          server: TelnetServer,
          encoding: str,
          style: BaseStyle | None,
          vt100_input: PipeInput,
          enable_cpr: bool = True,
      ) -> None:
          self.conn = conn
          self.addr = addr
          self.interact = interact
          self.server = server
          self.encoding = encoding
          self.style = style
          self._closed = False
          # Set when the output exists (terminal type received) or when the
          # connection is closed, whichever happens first.
          self._ready = asyncio.Event()
          self.vt100_input = vt100_input
          self.enable_cpr = enable_cpr
          self.vt100_output: Vt100_Output | None = None
          # Context of the running application; captured in `run_application`.
          self.context: contextvars.Context | None = None

          # Default terminal size, replaced by `size_received`.
          self.size = Size(rows=40, columns=79)

          # Initialize.
          _initialize_telnet(conn)

          # Create output.
          def get_size() -> Size:
              return self.size

          self.stdout = cast(TextIO, _ConnectionStdout(conn, encoding=encoding))

          def data_received(data: bytes) -> None:
              """TelnetProtocolParser 'data_received' callback"""
              self.vt100_input.send_bytes(data)

          def size_received(rows: int, columns: int) -> None:
              """TelnetProtocolParser 'size_received' callback"""
              self.size = Size(rows=rows, columns=columns)
              if self.vt100_output is not None and self.context is not None:
                  self.context.run(lambda: get_app()._on_resize())

          def ttype_received(ttype: str) -> None:
              """TelnetProtocolParser 'ttype_received' callback"""
              self.vt100_output = Vt100_Output(
                  self.stdout, get_size, term=ttype, enable_cpr=enable_cpr
              )
              self._ready.set()

          self.parser = TelnetProtocolParser(data_received, size_received, ttype_received)

      async def run_application(self) -> None:
          """
          Run application.

          Registers a reader for the client socket, waits until the terminal
          type is known, then awaits `interact(self)` inside an app session
          using this connection's input and output. The connection is closed
          when this returns. If the connection is closed before the terminal
          type arrives, this returns without calling `interact`.
          """

          def handle_incoming_data() -> None:
              try:
                  data = self.conn.recv(1024)
              except OSError as e:
                  # E.g. connection reset by the peer: treat it as a disconnect
                  # instead of letting the exception escape the loop callback.
                  logger.info("Connection error %r %r: %s", *self.addr, e)
                  self.close()
                  return
              if data:
                  self.feed(data)
              else:
                  # Connection closed by client.
                  logger.info("Connection closed by client. {!r} {!r}".format(*self.addr))
                  self.close()

          # Add reader.
          loop = get_running_loop()
          loop.add_reader(self.conn, handle_incoming_data)

          try:
              # Wait for vt100_output to be instantiated, or for the connection
              # to be closed before that ever happens.
              await self._ready.wait()
              if self._closed:
                  return
              with create_app_session(input=self.vt100_input, output=self.vt100_output):
                  self.context = contextvars.copy_context()
                  await self.interact(self)
          finally:
              self.close()

      def feed(self, data: bytes) -> None:
          """
          Handler for incoming data: pass it to the telnet protocol parser.
          """
          self.parser.feed(data)

      def close(self) -> None:
          """
          Close the connection: stop reading from the socket, then close the
          pipe input, the socket and the output wrapper. Calling it again is a
          no-op.
          """
          if not self._closed:
              self._closed = True
              # Wake `run_application` if it is still waiting for the terminal
              # type; without this it would wait forever.
              self._ready.set()
              self.vt100_input.close()
              get_running_loop().remove_reader(self.conn)
              self.conn.close()
              self.stdout.close()

      def send(self, formatted_text: AnyFormattedText) -> None:
          """
          Send text to the client. Does nothing before the terminal type is known.
          """
          if self.vt100_output is None:
              return
          formatted_text = to_formatted_text(formatted_text)
          print_formatted_text(
              self.vt100_output, formatted_text, self.style or DummyStyle()
          )

      def send_above_prompt(self, formatted_text: AnyFormattedText) -> None:
          """
          Send text to the client, printed above the running application.

          The printing is handed to `run_in_terminal` inside the application's
          context; the awaitable it returns is not awaited here, and this
          method returns None.

          :raises RuntimeError: When called outside `run_application`.
          """
          formatted_text = to_formatted_text(formatted_text)
          self._run_in_terminal(lambda: self.send(formatted_text))

      def _run_in_terminal(self, func: Callable[[], None]) -> None:
          # Make sure that when an application was active for this connection,
          # that we print the text above the application.
          if self.context is not None:
              self.context.run(run_in_terminal, func)
          else:
              raise RuntimeError("Called _run_in_terminal outside `run_application`.")

      def erase_screen(self) -> None:
          """
          Erase the screen and move the cursor to the top.
          """
          if self.vt100_output is None:
              return
          self.vt100_output.erase_screen()
          self.vt100_output.cursor_goto(0, 0)
          self.vt100_output.flush()
  async def _dummy_interact(connection: TelnetConnection) -> None:
      """Default `interact` coroutine: ignores the connection and returns at once."""
      return None
  class TelnetServer:
      """
      Telnet server implementation.

      Every accepted connection gets its own task that awaits the `interact`
      coroutine with a `TelnetConnection`. Connections whose application is
      running are tracked in `connections`.

      Example::

          async def interact(connection):
              connection.send("Welcome")
              session = PromptSession()
              result = await session.prompt_async(message="Say something: ")
              connection.send(f"You said: {result}\\n")

          async def main():
              server = TelnetServer(interact=interact, port=2323)
              await server.run()
      """

      def __init__(
          self,
          host: str = "127.0.0.1",
          port: int = 23,
          interact: Callable[
              [TelnetConnection], Coroutine[Any, Any, None]
          ] = _dummy_interact,
          encoding: str = "utf-8",
          style: BaseStyle | None = None,
          enable_cpr: bool = True,
      ) -> None:
          self.host = host
          self.port = port
          self.interact = interact
          self.encoding = encoding
          self.style = style
          self.enable_cpr = enable_cpr
          # Task created by the deprecated `start()` API.
          self._run_task: asyncio.Task[None] | None = None
          # One task per accepted connection; removed again when it is done.
          self._application_tasks: list[asyncio.Task[None]] = []
          # Connections whose application is currently running.
          self.connections: set[TelnetConnection] = set()

      @classmethod
      def _create_socket(cls, host: str, port: int) -> socket.socket:
          """
          Create a non-blocking IPv4 TCP socket bound to ``(host, port)`` and
          listening. The socket is closed again if any setup step fails.
          """
          s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
          try:
              s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
              s.bind((host, port))
              s.listen(4)
              # accept() is only called from the event loop's reader callback;
              # a spurious wake-up must not block the whole loop.
              s.setblocking(False)
          except BaseException:
              s.close()
              raise
          return s

      async def run(self, ready_cb: Callable[[], None] | None = None) -> None:
          """
          Run the telnet server, until this gets cancelled.

          :param ready_cb: Callback that will be called at the point that we're
              actually listening. If it raises, the listening socket is closed
              again.
          """
          listen_socket = self._create_socket(self.host, self.port)
          logger.info(
              "Listening for telnet connections on %s port %r", self.host, self.port
          )
          loop = get_running_loop()
          loop.add_reader(listen_socket, lambda: self._accept(listen_socket))

          try:
              if ready_cb:
                  ready_cb()
              # Run forever, until cancelled.
              await asyncio.Future()
          finally:
              loop.remove_reader(listen_socket)
              listen_socket.close()

              # Wait for all applications to finish.
              for t in list(self._application_tasks):
                  t.cancel()

              # (This is similar to
              # `Application.cancel_and_wait_for_background_tasks`. We wait for the
              # background tasks to complete, but don't propagate exceptions, because
              # we can't use `ExceptionGroup` yet.)
              if self._application_tasks:
                  await asyncio.wait(
                      list(self._application_tasks),
                      timeout=None,
                      return_when=asyncio.ALL_COMPLETED,
                  )

      def start(self) -> None:
          """
          Deprecated: Use `.run()` instead.

          Start the telnet server (stop by calling and awaiting `stop()`).
          """
          if self._run_task is not None:
              # Already running.
              return
          self._run_task = get_running_loop().create_task(self.run())

      async def stop(self) -> None:
          """
          Deprecated: Use `.run()` instead.

          Stop a telnet server that was started using `.start()` and wait for the
          cancellation to complete. Afterwards `start()` can be called again.
          """
          if self._run_task is None:
              return
          self._run_task.cancel()
          try:
              await self._run_task
          except asyncio.CancelledError:
              pass
          finally:
              self._run_task = None

      def _accept(self, listen_socket: socket.socket) -> None:
          """
          Accept new incoming connection.

          Called by the event loop when the listening socket is readable; starts
          a task that runs the application for the new connection.
          """
          try:
              conn, addr = listen_socket.accept()
          except (BlockingIOError, InterruptedError, ConnectionAbortedError):
              # Spurious wake-up, or the client went away before we accepted.
              return
          # The listening socket is non-blocking; keep connection sockets in
          # blocking mode regardless of platform-specific flag inheritance.
          conn.setblocking(True)
          logger.info("New connection %r %r", *addr)

          task = get_running_loop().create_task(self._run_connection(conn, addr))
          self._application_tasks.append(task)

          def on_done(t: asyncio.Task[None]) -> None:
              # Runs even when the task was cancelled before it started, a case
              # in which no code inside `_run_connection` executes.
              self._application_tasks.remove(t)
              # Usually already closed by the connection (closing again is a
              # no-op); otherwise this avoids leaking the socket.
              conn.close()

          task.add_done_callback(on_done)

      async def _run_connection(
          self, conn: socket.socket, addr: tuple[str, int]
      ) -> None:
          """Run the `interact` application for one accepted connection."""
          try:
              with create_pipe_input() as vt100_input:
                  try:
                      connection = TelnetConnection(
                          conn,
                          addr,
                          self.interact,
                          self,
                          encoding=self.encoding,
                          style=self.style,
                          vt100_input=vt100_input,
                          enable_cpr=self.enable_cpr,
                      )
                  except BaseException:
                      # No connection object owns the socket yet; don't leak it.
                      conn.close()
                      raise
                  self.connections.add(connection)
                  logger.info("Starting interaction %r %r", *addr)
                  try:
                      await connection.run_application()
                  finally:
                      self.connections.remove(connection)
                      logger.info("Stopping interaction %r %r", *addr)
          except EOFError:
              # Happens either when the connection is closed by the client
              # (e.g., when the user types 'control-]', then 'quit' in the
              # telnet client) or when the user types control-d in a prompt
              # and this is not handled by the interact function.
              logger.info("Unhandled EOFError in telnet application.")
          except KeyboardInterrupt:
              # Unhandled control-c propagated by a prompt.
              logger.info("Unhandled KeyboardInterrupt in telnet application.")
          except asyncio.CancelledError:
              # `run()` cancels every application task on shutdown. That is
              # expected: let the task end as cancelled instead of printing
              # a traceback for it.
              raise
          except BaseException as e:
              print(f"Got {type(e).__name__}", e)
              import traceback

              traceback.print_exc()