websocket.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. """Tornado websocket handler to serve a terminal interface.
  2. """
  3. # Copyright (c) Jupyter Development Team
  4. # Copyright (c) 2014, Ramalingam Saravanan <sarava@sarava.net>
  5. # Distributed under the terms of the Simplified BSD License.
  6. from __future__ import annotations
  7. import json
  8. import logging
  9. import os
  10. from typing import TYPE_CHECKING, Any
  11. import tornado.websocket
  12. from tornado import gen
  13. from tornado.concurrent import run_on_executor
  14. if TYPE_CHECKING:
  15. from terminado.management import PtyWithClients, TermManagerBase
  16. def _cast_unicode(s: str | bytes) -> str:
  17. if isinstance(s, bytes):
  18. return s.decode("utf-8")
  19. return s
  20. class TermSocket(tornado.websocket.WebSocketHandler):
  21. """Handler for a terminal websocket"""
  22. def initialize(self, term_manager: TermManagerBase) -> None:
  23. """Initialize the handler."""
  24. self.term_manager = term_manager
  25. self.term_name = ""
  26. self.size = (None, None)
  27. self.terminal: PtyWithClients | None = None
  28. self._blocking_io_executor = term_manager.blocking_io_executor
  29. self._logger = logging.getLogger(__name__)
  30. self._user_command = ""
  31. # Enable if the environment variable LOG_TERMINAL_OUTPUT is "true"
  32. self._enable_output_logging = str.lower(os.getenv("LOG_TERMINAL_OUTPUT", "false")) == "true"
  33. def origin_check(self, origin: str | None = None) -> bool:
  34. """Deprecated: backward-compat for terminado <= 0.5."""
  35. origin = origin or self.request.headers.get("Origin", "")
  36. assert origin is not None
  37. return self.check_origin(origin)
  38. def open(self, url_component: Any = None) -> None: # type:ignore[override]
  39. """Websocket connection opened.
  40. Call our terminal manager to get a terminal, and connect to it as a
  41. client.
  42. """
  43. # Jupyter has a mixin to ping websockets and keep connections through
  44. # proxies alive. Call super() to allow that to set up:
  45. super().open(url_component)
  46. self._logger.info("TermSocket.open: %s", url_component)
  47. url_component = _cast_unicode(url_component)
  48. self.term_name = url_component or "tty"
  49. self.terminal = self.term_manager.get_terminal(url_component)
  50. self.terminal.clients.append(self)
  51. self.send_json_message(["setup", {}])
  52. self._logger.info("TermSocket.open: Opened %s", self.term_name)
  53. # Now drain the preopen buffer, if reconnect.
  54. buffered = ""
  55. preopen_buffer = self.terminal.read_buffer.copy()
  56. while True:
  57. if not preopen_buffer:
  58. break
  59. s = preopen_buffer.popleft()
  60. buffered += s
  61. if buffered:
  62. self.on_pty_read(buffered)
  63. def on_pty_read(self, text: str) -> None:
  64. """Data read from pty; send to frontend"""
  65. self.send_json_message(["stdout", text])
  66. def send_json_message(self, content: Any) -> None:
  67. """Send a json message on the socket."""
  68. json_msg = json.dumps(content)
  69. self.write_message(json_msg)
  70. if self._enable_output_logging and content[0] == "stdout" and isinstance(content[1], str):
  71. self.log_terminal_output(f"STDOUT: {content[1]}")
  72. @gen.coroutine
  73. def on_message(self, message: str) -> None: # type:ignore[misc]
  74. """Handle incoming websocket message
  75. We send JSON arrays, where the first element is a string indicating
  76. what kind of message this is. Data associated with the message follows.
  77. """
  78. # logging.info("TermSocket.on_message: %s - (%s) %s", self.term_name, type(message), len(message) if isinstance(message, bytes) else message[:250])
  79. command = json.loads(message)
  80. msg_type = command[0]
  81. assert self.terminal is not None
  82. if msg_type == "stdin":
  83. yield self.stdin_to_ptyproc(command[1])
  84. if self._enable_output_logging:
  85. if command[1] == "\r":
  86. self.log_terminal_output(f"STDIN: {self._user_command}")
  87. self._user_command = ""
  88. else:
  89. self._user_command += command[1]
  90. elif msg_type == "set_size":
  91. self.size = command[1:3]
  92. self.terminal.resize_to_smallest()
  93. def on_close(self) -> None:
  94. """Handle websocket closing.
  95. Disconnect from our terminal, and tell the terminal manager we're
  96. disconnecting.
  97. """
  98. self._logger.info("Websocket closed")
  99. if self.terminal:
  100. self.terminal.clients.remove(self)
  101. self.terminal.resize_to_smallest()
  102. self.term_manager.client_disconnected(self)
  103. def on_pty_died(self) -> None:
  104. """Terminal closed: tell the frontend, and close the socket."""
  105. self.send_json_message(["disconnect", 1])
  106. self.close()
  107. self.terminal = None
  108. def log_terminal_output(self, log: str = "") -> None:
  109. """
  110. Logs the terminal input/output
  111. :param log: log line to write
  112. :return:
  113. """
  114. self._logger.debug(log)
  115. @run_on_executor(executor="_blocking_io_executor")
  116. def stdin_to_ptyproc(self, text: str) -> None:
  117. """Handles stdin messages sent on the websocket.
  118. This is a blocking call that should NOT be performed inside the
  119. server primary event loop thread. Messages must be handled
  120. asynchronously to prevent blocking on the PTY buffer.
  121. """
  122. if self.terminal is not None:
  123. self.terminal.ptyproc.write(text)