| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314 |
- """Gateway API handlers."""
- # Copyright (c) Jupyter Development Team.
- # Distributed under the terms of the Modified BSD License.
- from __future__ import annotations
- import asyncio
- import logging
- import mimetypes
- import os
- import random
- import warnings
- from typing import Any, Optional, cast
- from jupyter_client.session import Session
- from tornado import web
- from tornado.concurrent import Future
- from tornado.escape import json_decode, url_escape, utf8
- from tornado.httpclient import HTTPRequest
- from tornado.ioloop import IOLoop, PeriodicCallback
- from tornado.websocket import WebSocketHandler, websocket_connect
- from traitlets.config.configurable import LoggingConfigurable
- from ..base.handlers import APIHandler, JupyterHandler
- from ..utils import url_path_join
- from .gateway_client import GatewayClient
# Importing this module is deprecated; tell whoever imported it (stacklevel=2).
warnings.warn(
    message=(
        "The jupyter_server.gateway.handlers module is deprecated and "
        "will not be supported in Jupyter Server 3.0"
    ),
    category=DeprecationWarning,
    stacklevel=2,
)

# Keepalive ping interval in seconds, read once at import time from the
# GATEWAY_WS_PING_INTERVAL_SECS environment variable (default: 30 seconds).
GATEWAY_WS_PING_INTERVAL_SECS = int(os.environ.get("GATEWAY_WS_PING_INTERVAL_SECS", "30"))
class WebSocketChannelsHandler(WebSocketHandler, JupyterHandler):
    """Gateway web socket channels handler.

    Serves the notebook client's kernel-channels websocket and relays every
    message to and from the gateway through a ``GatewayWebSocketClient``.
    """

    # Class-level defaults; initialize(), get() and open() replace them per instance.
    session = None
    gateway = None
    kernel_id = None
    ping_callback = None

    def check_origin(self, origin=None):
        """Check origin for the socket.

        Calls ``JupyterHandler.check_origin`` explicitly instead of going
        through the MRO, in which ``WebSocketHandler`` is listed first.
        """
        return JupyterHandler.check_origin(self, origin)

    def set_default_headers(self):
        """Undo the set_default_headers in JupyterHandler which doesn't make sense for websockets"""

    def get_compression_options(self):
        """Get the compression options for the socket.

        Returns an empty dict rather than ``None``; per tornado's websocket
        documentation any non-``None`` value turns compression on.
        """
        # use deflate compress websocket
        return {}

    def authenticate(self):
        """Run before finishing the GET request

        Extend this method to add logic that should fire before
        the websocket finishes completing.

        Raises ``web.HTTPError(403)`` when there is no current user. When a
        ``session_id`` request argument is present it is copied onto
        ``self.session.session``; otherwise a warning is logged.
        """
        # authenticate the request before opening the websocket
        if self.current_user is None:
            self.log.warning("Couldn't authenticate WebSocket connection")
            raise web.HTTPError(403)

        if self.get_argument("session_id", None):
            assert self.session is not None
            self.session.session = self.get_argument("session_id")  # type:ignore[unreachable]
        else:
            self.log.warning("No session ID specified")

    def initialize(self):
        """Initialize the socket.

        Creates the per-connection ``Session`` and the ``GatewayWebSocketClient``
        that proxies to the gateway.
        """
        self.log.debug("Initializing websocket connection %s", self.request.path)
        self.session = Session(config=self.config)
        self.gateway = GatewayWebSocketClient(gateway_url=GatewayClient.instance().url)

    async def get(self, kernel_id, *args, **kwargs):
        """Get the socket.

        Authenticates first, records ``kernel_id`` and forwards it to the
        parent ``get`` as a keyword argument.
        """
        self.authenticate()
        self.kernel_id = kernel_id
        kwargs["kernel_id"] = kernel_id
        await super().get(*args, **kwargs)

    def send_ping(self):
        """Send a ping to the socket.

        Once the connection is gone the periodic callback (if any) is stopped
        and no ping is attempted.
        """
        if self.ws_connection is None:
            # Never ping a closed connection, whether or not a callback exists.
            if self.ping_callback is not None:
                self.ping_callback.stop()  # type:ignore[unreachable]
            return

        self.ping(b"")

    def open(self, kernel_id, *args, **kwargs):
        """Handle web socket connection open to notebook server and delegate to gateway web socket handler"""
        # The interval constant is in seconds; it is scaled by 1000 for PeriodicCallback.
        self.ping_callback = PeriodicCallback(self.send_ping, GATEWAY_WS_PING_INTERVAL_SECS * 1000)
        self.ping_callback.start()

        assert self.gateway is not None
        self.gateway.on_open(
            kernel_id=kernel_id,
            message_callback=self.write_message,
            compression_options=self.get_compression_options(),
        )

    def on_message(self, message):
        """Forward message to gateway web socket handler."""
        assert self.gateway is not None
        self.gateway.on_message(message)

    def write_message(self, message, binary=False):
        """Send message back to notebook client. This is called via callback from self.gateway._read_messages.

        ``bytes`` payloads are always sent as binary frames. If the client
        connection is already gone the message is dropped, and a summary of it
        is logged when debug logging is enabled.
        """
        if self.ws_connection:  # prevent WebSocketClosedError
            if isinstance(message, bytes):
                binary = True
            super().write_message(message, binary=binary)
        elif self.log.isEnabledFor(logging.DEBUG):
            try:
                msg_summary = WebSocketChannelsHandler._get_message_summary(
                    json_decode(utf8(message))
                )
            except Exception:
                # Logging a dropped message must never raise into the caller
                # (e.g. when the payload is not a JSON document).
                msg_summary = "<unparseable message>"
            self.log.debug(
                f"Notebook client closed websocket connection - message dropped: {msg_summary}"
            )

    def on_close(self):
        """Handle a closing socket.

        Stops the keepalive ping, then tells the gateway client to disconnect.
        """
        self.log.debug("Closing websocket connection %s", self.request.path)
        if self.ping_callback is not None:
            # Stop the timer now instead of waiting for its next tick to notice the close.
            self.ping_callback.stop()
        assert self.gateway is not None
        self.gateway.on_close()
        super().on_close()

    @staticmethod
    def _get_message_summary(message):
        """Get a summary of a message.

        ``message`` is a decoded message dict with a top-level ``msg_type``.
        Only ``status`` and ``error`` messages have content included.
        """
        summary = []
        message_type = message["msg_type"]
        summary.append(f"type: {message_type}")

        if message_type == "status":
            summary.append(", state: {}".format(message["content"]["execution_state"]))
        elif message_type == "error":
            summary.append(
                ", {}:{}:{}".format(
                    message["content"]["ename"],
                    message["content"]["evalue"],
                    message["content"]["traceback"],
                )
            )
        else:
            summary.append(", ...")  # don't display potentially sensitive data

        return "".join(summary)
class GatewayWebSocketClient(LoggingConfigurable):
    """Proxy web socket connection to a kernel/enterprise gateway."""

    def __init__(self, **kwargs):
        """Initialize the gateway web socket client.

        ``kwargs`` are accepted but neither used nor forwarded to the base class.
        """
        super().__init__()
        self.kernel_id = None
        self.ws = None
        self.ws_future: Future[Any] = Future()
        self.disconnected = False
        self.retry = 0

    async def _connect(self, kernel_id, message_callback):
        """Connect to the socket.

        Starts a websocket connection to the gateway's ``channels`` endpoint for
        ``kernel_id`` and arranges for ``_read_messages`` to run with
        ``message_callback`` once the connection future resolves.
        """
        # websocket is initialized before connection
        self.ws = None
        self.kernel_id = kernel_id
        client = GatewayClient.instance()
        assert client.ws_url is not None

        ws_url = url_path_join(
            client.ws_url,
            client.kernels_endpoint,
            url_escape(kernel_id),
            "channels",
        )
        self.log.info(f"Connecting to {ws_url}")
        kwargs: dict[str, Any] = {}
        kwargs = client.load_connection_args(**kwargs)

        request = HTTPRequest(ws_url, **kwargs)
        self.ws_future = cast("Future[Any]", websocket_connect(request))
        self.ws_future.add_done_callback(self._connection_done)

        loop = IOLoop.current()
        loop.add_future(self.ws_future, lambda future: self._read_messages(message_callback))

    def _connection_done(self, fut):
        """Handle a finished connection.

        On success the connection is stored and the retry counter is reset. If
        the client disconnected in the meantime, a connection that still
        completed is closed so it is not leaked, and a warning is logged.
        """
        # exception()/result() raise CancelledError on a cancelled future, so
        # check for cancellation before inspecting the outcome.
        succeeded = not fut.cancelled() and fut.exception() is None

        if succeeded and not self.disconnected:
            self.ws = fut.result()
            self.retry = 0
            self.log.debug(f"Connection is ready: ws: {self.ws}")
            return

        if succeeded:
            # Connected only after the client disconnected: nobody will ever
            # read from or close this socket, so close it here.
            fut.result().close()

        self.log.warning(
            "Websocket connection has been closed via client disconnect or due to error.  "
            f"Kernel with ID '{self.kernel_id}' may not be terminated on GatewayClient: {GatewayClient.instance().url}"
        )

    def _disconnect(self):
        """Handle a disconnect.

        Marks the client as disconnected, then closes the open connection or
        cancels the pending connection attempt.
        """
        self.disconnected = True
        if self.ws is not None:
            # Close connection
            self.ws.close()
        elif not self.ws_future.done():
            # Cancel pending connection.  Since future.cancel() is a noop on tornado, we'll track cancellation locally
            self.ws_future.cancel()
            self.log.debug(f"_disconnect: future cancelled, disconnected: {self.disconnected}")

    async def _read_messages(self, callback):
        """Read messages from gateway server.

        Passes each message to ``callback`` until the connection is lost or the
        client disconnects. On an unexpected loss, schedules a reconnect with
        exponential backoff while the retry budget allows.
        """
        while self.ws is not None:
            message = None
            if not self.disconnected:
                try:
                    message = await self.ws.read_message()
                except Exception as e:
                    self.log.error(
                        f"Exception reading message from websocket: {e}"
                    )  # , exc_info=True)
                if message is None:
                    if not self.disconnected:
                        self.log.warning(f"Lost connection to Gateway: {self.kernel_id}")
                    break
                callback(
                    message
                )  # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open)
            else:  # ws cancelled - stop reading
                break

        # NOTE(esevan): if websocket is not disconnected by client, try to reconnect.
        if not self.disconnected and self.retry < GatewayClient.instance().gateway_retry_max:
            # Jitter is a random 0.10 to 1.00 seconds added on top of the backoff.
            jitter = random.randint(10, 100) * 0.01  # noqa: S311
            retry_interval = (
                min(
                    GatewayClient.instance().gateway_retry_interval * (2**self.retry),
                    GatewayClient.instance().gateway_retry_interval_max,
                )
                + jitter
            )
            self.retry += 1
            self.log.info(
                "Attempting to re-establish the connection to Gateway in %s secs (%s/%s): %s",
                retry_interval,
                self.retry,
                GatewayClient.instance().gateway_retry_max,
                self.kernel_id,
            )
            await asyncio.sleep(retry_interval)
            if self.disconnected:
                # The client went away during the backoff; a reconnect would
                # open a gateway socket that nothing owns.
                return
            loop = IOLoop.current()
            loop.spawn_callback(self._connect, self.kernel_id, callback)

    def on_open(self, kernel_id, message_callback, **kwargs):
        """Web socket connection open against gateway server.

        Extra ``kwargs`` are accepted and ignored.
        """
        loop = IOLoop.current()
        loop.spawn_callback(self._connect, kernel_id, message_callback)

    def on_message(self, message):
        """Send message to gateway server.

        If the connection is not established yet, the write is deferred until
        the current connection future resolves.
        """
        if self.ws is None:
            loop = IOLoop.current()
            loop.add_future(self.ws_future, lambda future: self._write_message(message))
        else:
            self._write_message(message)

    def _write_message(self, message):
        """Send message to gateway server.

        Does nothing when disconnected or when there is no connection; write
        errors are logged rather than raised.
        """
        try:
            if not self.disconnected and self.ws is not None:
                self.ws.write_message(message)
        except Exception as e:
            self.log.error(f"Exception writing message to websocket: {e}")  # , exc_info=True)

    def on_close(self):
        """Web socket closed event."""
        self._disconnect()
class GatewayResourceHandler(APIHandler):
    """Retrieves resources for specific kernelspec definitions from kernel/enterprise gateway."""

    @web.authenticated
    async def get(self, kernel_name, path, include_body=True):
        """Get a gateway resource by name and path.

        Fetches the resource from the kernel spec manager and finishes the
        response with it, using a content type guessed from ``path`` (falling
        back to ``text/plain``). ``include_body`` is accepted but not used.
        """
        resource = await self.kernel_spec_manager.get_kernel_spec_resource(  # type:ignore[attr-defined]
            kernel_name, path
        )

        content_type: Optional[str] = None
        if resource is not None:
            guessed, _encoding = mimetypes.guess_type(path)
            content_type = guessed or "text/plain"
        else:
            # NOTE(review): the response is still finished below with no body and
            # a content type of None -- confirm APIHandler.finish accepts that.
            self.log.warning(
                f"Kernelspec resource '{path}' for '{kernel_name}' not found.  Gateway may not support"
                " resource serving."
            )

        self.finish(resource, set_content_type=content_type)
- from ..services.kernels.handlers import _kernel_id_regex
- from ..services.kernelspecs.handlers import kernel_name_regex
# URL routes exported by this module: the kernel channels websocket and the
# kernelspec resource endpoint (the regex fragments are assumed to be plain strings).
default_handlers = [
    (rf"/api/kernels/{_kernel_id_regex}/channels", WebSocketChannelsHandler),
    (rf"/kernelspecs/{kernel_name_regex}/(?P<path>.*)", GatewayResourceHandler),
]
|