websocket.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. """Tornado handlers for WebSocket <-> ZMQ sockets."""
  2. # Copyright (c) Jupyter Development Team.
  3. # Distributed under the terms of the Modified BSD License.
  4. from jupyter_core.utils import ensure_async
  5. from tornado import web
  6. from tornado.websocket import WebSocketHandler
  7. from jupyter_server.auth.decorator import ws_authenticated
  8. from jupyter_server.base.handlers import JupyterHandler
  9. from jupyter_server.base.websocket import WebSocketMixin
  10. AUTH_RESOURCE = "kernels"
  11. class KernelWebsocketHandler(WebSocketMixin, WebSocketHandler, JupyterHandler): # type:ignore[misc]
  12. """The kernels websocket should connect"""
  13. auth_resource = AUTH_RESOURCE
  14. @property
  15. def kernel_websocket_connection_class(self):
  16. """The kernel websocket connection class."""
  17. return self.settings.get("kernel_websocket_connection_class")
  18. def set_default_headers(self):
  19. """Undo the set_default_headers in JupyterHandler
  20. which doesn't make sense for websockets
  21. """
  22. def get_compression_options(self):
  23. """Get the socket connection options."""
  24. return self.settings.get("websocket_compression_options", None)
  25. async def pre_get(self):
  26. """Handle a pre_get."""
  27. user = self.current_user
  28. # authorize the user.
  29. authorized = await ensure_async(
  30. self.authorizer.is_authorized(self, user, "execute", "kernels")
  31. )
  32. if not authorized:
  33. raise web.HTTPError(403)
  34. kernel = self.kernel_manager.get_kernel(self.kernel_id)
  35. self.connection = self.kernel_websocket_connection_class(
  36. parent=kernel, websocket_handler=self, config=self.config
  37. )
  38. if self.get_argument("session_id", None):
  39. self.connection.session.session = self.get_argument("session_id")
  40. else:
  41. self.log.warning("No session ID specified")
  42. # For backwards compatibility with older versions
  43. # of the websocket connection, call a prepare method if found.
  44. if hasattr(self.connection, "prepare"):
  45. await self.connection.prepare()
  46. @ws_authenticated
  47. async def get(self, kernel_id):
  48. """Handle a get request for a kernel."""
  49. self.kernel_id = kernel_id
  50. await self.pre_get()
  51. await super().get(kernel_id=kernel_id)
  52. async def open(self, kernel_id):
  53. """Open a kernel websocket."""
  54. # Need to call super here to make sure we
  55. # begin a ping-pong loop with the client.
  56. super().open()
  57. # Wait for the kernel to emit an idle status.
  58. self.log.info(f"Connecting to kernel {self.kernel_id}.")
  59. await self.connection.connect()
  60. def on_message(self, ws_message):
  61. """Get a kernel message from the websocket and turn it into a ZMQ message."""
  62. self.connection.handle_incoming_message(ws_message)
  63. def on_close(self):
  64. """Handle a socket closure."""
  65. self.connection.disconnect()
  66. self.connection = None
  67. def select_subprotocol(self, subprotocols):
  68. """Select the sub protocol for the socket."""
  69. preferred_protocol = self.connection.kernel_ws_protocol
  70. if preferred_protocol is None:
  71. preferred_protocol = "v1.kernel.websocket.jupyter.org"
  72. elif preferred_protocol == "":
  73. preferred_protocol = None
  74. selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None
  75. # None is the default, "legacy" protocol
  76. return selected_subprotocol