server.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. """
  2. Utility for running a prompt_toolkit application in an asyncssh server.
  3. """
  4. from __future__ import annotations
  5. import asyncio
  6. import traceback
  7. from asyncio import get_running_loop
  8. from typing import Any, Callable, Coroutine, TextIO, cast
  9. import asyncssh
  10. from prompt_toolkit.application.current import AppSession, create_app_session
  11. from prompt_toolkit.data_structures import Size
  12. from prompt_toolkit.input import PipeInput, create_pipe_input
  13. from prompt_toolkit.output.vt100 import Vt100_Output
  14. __all__ = ["PromptToolkitSSHSession", "PromptToolkitSSHServer"]
  15. class PromptToolkitSSHSession(asyncssh.SSHServerSession): # type: ignore
  16. def __init__(
  17. self,
  18. interact: Callable[[PromptToolkitSSHSession], Coroutine[Any, Any, None]],
  19. *,
  20. enable_cpr: bool,
  21. ) -> None:
  22. self.interact = interact
  23. self.enable_cpr = enable_cpr
  24. self.interact_task: asyncio.Task[None] | None = None
  25. self._chan: Any | None = None
  26. self.app_session: AppSession | None = None
  27. # PipInput object, for sending input in the CLI.
  28. # (This is something that we can use in the prompt_toolkit event loop,
  29. # but still write date in manually.)
  30. self._input: PipeInput | None = None
  31. self._output: Vt100_Output | None = None
  32. # Output object. Don't render to the real stdout, but write everything
  33. # in the SSH channel.
  34. class Stdout:
  35. def write(s, data: str) -> None:
  36. try:
  37. if self._chan is not None:
  38. self._chan.write(data.replace("\n", "\r\n"))
  39. except BrokenPipeError:
  40. pass # Channel not open for sending.
  41. def isatty(s) -> bool:
  42. return True
  43. def flush(s) -> None:
  44. pass
  45. @property
  46. def encoding(s) -> str:
  47. assert self._chan is not None
  48. return str(self._chan._orig_chan.get_encoding()[0])
  49. self.stdout = cast(TextIO, Stdout())
  50. def _get_size(self) -> Size:
  51. """
  52. Callable that returns the current `Size`, required by Vt100_Output.
  53. """
  54. if self._chan is None:
  55. return Size(rows=20, columns=79)
  56. else:
  57. width, height, pixwidth, pixheight = self._chan.get_terminal_size()
  58. return Size(rows=height, columns=width)
  59. def connection_made(self, chan: Any) -> None:
  60. self._chan = chan
  61. def shell_requested(self) -> bool:
  62. return True
  63. def session_started(self) -> None:
  64. self.interact_task = get_running_loop().create_task(self._interact())
  65. async def _interact(self) -> None:
  66. if self._chan is None:
  67. # Should not happen.
  68. raise Exception("`_interact` called before `connection_made`.")
  69. if hasattr(self._chan, "set_line_mode") and self._chan._editor is not None:
  70. # Disable the line editing provided by asyncssh. Prompt_toolkit
  71. # provides the line editing.
  72. self._chan.set_line_mode(False)
  73. term = self._chan.get_terminal_type()
  74. self._output = Vt100_Output(
  75. self.stdout, self._get_size, term=term, enable_cpr=self.enable_cpr
  76. )
  77. with create_pipe_input() as self._input:
  78. with create_app_session(input=self._input, output=self._output) as session:
  79. self.app_session = session
  80. try:
  81. await self.interact(self)
  82. except BaseException:
  83. traceback.print_exc()
  84. finally:
  85. # Close the connection.
  86. self._chan.close()
  87. self._input.close()
  88. def terminal_size_changed(
  89. self, width: int, height: int, pixwidth: object, pixheight: object
  90. ) -> None:
  91. # Send resize event to the current application.
  92. if self.app_session and self.app_session.app:
  93. self.app_session.app._on_resize()
  94. def data_received(self, data: str, datatype: object) -> None:
  95. if self._input is None:
  96. # Should not happen.
  97. return
  98. self._input.send_text(data)
  99. class PromptToolkitSSHServer(asyncssh.SSHServer):
  100. """
  101. Run a prompt_toolkit application over an asyncssh server.
  102. This takes one argument, an `interact` function, which is called for each
  103. connection. This should be an asynchronous function that runs the
  104. prompt_toolkit applications. This function runs in an `AppSession`, which
  105. means that we can have multiple UI interactions concurrently.
  106. Example usage:
  107. .. code:: python
  108. async def interact(ssh_session: PromptToolkitSSHSession) -> None:
  109. await yes_no_dialog("my title", "my text").run_async()
  110. prompt_session = PromptSession()
  111. text = await prompt_session.prompt_async("Type something: ")
  112. print_formatted_text('You said: ', text)
  113. server = PromptToolkitSSHServer(interact=interact)
  114. loop = get_running_loop()
  115. loop.run_until_complete(
  116. asyncssh.create_server(
  117. lambda: MySSHServer(interact),
  118. "",
  119. port,
  120. server_host_keys=["/etc/ssh/..."],
  121. )
  122. )
  123. loop.run_forever()
  124. :param enable_cpr: When `True`, the default, try to detect whether the SSH
  125. client runs in a terminal that responds to "cursor position requests".
  126. That way, we can properly determine how much space there is available
  127. for the UI (especially for drop down menus) to render.
  128. """
  129. def __init__(
  130. self,
  131. interact: Callable[[PromptToolkitSSHSession], Coroutine[Any, Any, None]],
  132. *,
  133. enable_cpr: bool = True,
  134. ) -> None:
  135. self.interact = interact
  136. self.enable_cpr = enable_cpr
  137. def begin_auth(self, username: str) -> bool:
  138. # No authentication.
  139. return False
  140. def session_requested(self) -> PromptToolkitSSHSession:
  141. return PromptToolkitSSHSession(self.interact, enable_cpr=self.enable_cpr)