handlers.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. """Gateway API handlers."""
  2. # Copyright (c) Jupyter Development Team.
  3. # Distributed under the terms of the Modified BSD License.
  4. from __future__ import annotations
  5. import asyncio
  6. import logging
  7. import mimetypes
  8. import os
  9. import random
  10. import warnings
  11. from typing import Any, Optional, cast
  12. from jupyter_client.session import Session
  13. from tornado import web
  14. from tornado.concurrent import Future
  15. from tornado.escape import json_decode, url_escape, utf8
  16. from tornado.httpclient import HTTPRequest
  17. from tornado.ioloop import IOLoop, PeriodicCallback
  18. from tornado.websocket import WebSocketHandler, websocket_connect
  19. from traitlets.config.configurable import LoggingConfigurable
  20. from ..base.handlers import APIHandler, JupyterHandler
  21. from ..utils import url_path_join
  22. from .gateway_client import GatewayClient
  23. warnings.warn(
  24. "The jupyter_server.gateway.handlers module is deprecated and will not be supported in Jupyter Server 3.0",
  25. DeprecationWarning,
  26. stacklevel=2,
  27. )
  28. # Keepalive ping interval (default: 30 seconds)
  29. GATEWAY_WS_PING_INTERVAL_SECS = int(os.getenv("GATEWAY_WS_PING_INTERVAL_SECS", "30"))
  30. class WebSocketChannelsHandler(WebSocketHandler, JupyterHandler):
  31. """Gateway web socket channels handler."""
  32. session = None
  33. gateway = None
  34. kernel_id = None
  35. ping_callback = None
  36. def check_origin(self, origin=None):
  37. """Check origin for the socket."""
  38. return JupyterHandler.check_origin(self, origin)
  39. def set_default_headers(self):
  40. """Undo the set_default_headers in JupyterHandler which doesn't make sense for websockets"""
  41. def get_compression_options(self):
  42. """Get the compression options for the socket."""
  43. # use deflate compress websocket
  44. return {}
  45. def authenticate(self):
  46. """Run before finishing the GET request
  47. Extend this method to add logic that should fire before
  48. the websocket finishes completing.
  49. """
  50. # authenticate the request before opening the websocket
  51. if self.current_user is None:
  52. self.log.warning("Couldn't authenticate WebSocket connection")
  53. raise web.HTTPError(403)
  54. if self.get_argument("session_id", None):
  55. assert self.session is not None
  56. self.session.session = self.get_argument("session_id") # type:ignore[unreachable]
  57. else:
  58. self.log.warning("No session ID specified")
  59. def initialize(self):
  60. """Initialize the socket."""
  61. self.log.debug("Initializing websocket connection %s", self.request.path)
  62. self.session = Session(config=self.config)
  63. self.gateway = GatewayWebSocketClient(gateway_url=GatewayClient.instance().url)
  64. async def get(self, kernel_id, *args, **kwargs):
  65. """Get the socket."""
  66. self.authenticate()
  67. self.kernel_id = kernel_id
  68. kwargs["kernel_id"] = kernel_id
  69. await super().get(*args, **kwargs)
  70. def send_ping(self):
  71. """Send a ping to the socket."""
  72. if self.ws_connection is None and self.ping_callback is not None:
  73. self.ping_callback.stop() # type:ignore[unreachable]
  74. return
  75. self.ping(b"")
  76. def open(self, kernel_id, *args, **kwargs):
  77. """Handle web socket connection open to notebook server and delegate to gateway web socket handler"""
  78. self.ping_callback = PeriodicCallback(self.send_ping, GATEWAY_WS_PING_INTERVAL_SECS * 1000)
  79. self.ping_callback.start()
  80. assert self.gateway is not None
  81. self.gateway.on_open(
  82. kernel_id=kernel_id,
  83. message_callback=self.write_message,
  84. compression_options=self.get_compression_options(),
  85. )
  86. def on_message(self, message):
  87. """Forward message to gateway web socket handler."""
  88. assert self.gateway is not None
  89. self.gateway.on_message(message)
  90. def write_message(self, message, binary=False):
  91. """Send message back to notebook client. This is called via callback from self.gateway._read_messages."""
  92. if self.ws_connection: # prevent WebSocketClosedError
  93. if isinstance(message, bytes):
  94. binary = True
  95. super().write_message(message, binary=binary)
  96. elif self.log.isEnabledFor(logging.DEBUG):
  97. msg_summary = WebSocketChannelsHandler._get_message_summary(json_decode(utf8(message)))
  98. self.log.debug(
  99. f"Notebook client closed websocket connection - message dropped: {msg_summary}"
  100. )
  101. def on_close(self):
  102. """Handle a closing socket."""
  103. self.log.debug("Closing websocket connection %s", self.request.path)
  104. assert self.gateway is not None
  105. self.gateway.on_close()
  106. super().on_close()
  107. @staticmethod
  108. def _get_message_summary(message):
  109. """Get a summary of a message."""
  110. summary = []
  111. message_type = message["msg_type"]
  112. summary.append(f"type: {message_type}")
  113. if message_type == "status":
  114. summary.append(", state: {}".format(message["content"]["execution_state"]))
  115. elif message_type == "error":
  116. summary.append(
  117. ", {}:{}:{}".format(
  118. message["content"]["ename"],
  119. message["content"]["evalue"],
  120. message["content"]["traceback"],
  121. )
  122. )
  123. else:
  124. summary.append(", ...") # don't display potentially sensitive data
  125. return "".join(summary)
  126. class GatewayWebSocketClient(LoggingConfigurable):
  127. """Proxy web socket connection to a kernel/enterprise gateway."""
  128. def __init__(self, **kwargs):
  129. """Initialize the gateway web socket client."""
  130. super().__init__()
  131. self.kernel_id = None
  132. self.ws = None
  133. self.ws_future: Future[Any] = Future()
  134. self.disconnected = False
  135. self.retry = 0
  136. async def _connect(self, kernel_id, message_callback):
  137. """Connect to the socket."""
  138. # websocket is initialized before connection
  139. self.ws = None
  140. self.kernel_id = kernel_id
  141. client = GatewayClient.instance()
  142. assert client.ws_url is not None
  143. ws_url = url_path_join(
  144. client.ws_url,
  145. client.kernels_endpoint,
  146. url_escape(kernel_id),
  147. "channels",
  148. )
  149. self.log.info(f"Connecting to {ws_url}")
  150. kwargs: dict[str, Any] = {}
  151. kwargs = client.load_connection_args(**kwargs)
  152. request = HTTPRequest(ws_url, **kwargs)
  153. self.ws_future = cast("Future[Any]", websocket_connect(request))
  154. self.ws_future.add_done_callback(self._connection_done)
  155. loop = IOLoop.current()
  156. loop.add_future(self.ws_future, lambda future: self._read_messages(message_callback))
  157. def _connection_done(self, fut):
  158. """Handle a finished connection."""
  159. if (
  160. not self.disconnected and fut.exception() is None
  161. ): # prevent concurrent.futures._base.CancelledError
  162. self.ws = fut.result()
  163. self.retry = 0
  164. self.log.debug(f"Connection is ready: ws: {self.ws}")
  165. else:
  166. self.log.warning(
  167. "Websocket connection has been closed via client disconnect or due to error. "
  168. f"Kernel with ID '{self.kernel_id}' may not be terminated on GatewayClient: {GatewayClient.instance().url}"
  169. )
  170. def _disconnect(self):
  171. """Handle a disconnect."""
  172. self.disconnected = True
  173. if self.ws is not None:
  174. # Close connection
  175. self.ws.close()
  176. elif not self.ws_future.done():
  177. # Cancel pending connection. Since future.cancel() is a noop on tornado, we'll track cancellation locally
  178. self.ws_future.cancel()
  179. self.log.debug(f"_disconnect: future cancelled, disconnected: {self.disconnected}")
  180. async def _read_messages(self, callback):
  181. """Read messages from gateway server."""
  182. while self.ws is not None:
  183. message = None
  184. if not self.disconnected:
  185. try:
  186. message = await self.ws.read_message()
  187. except Exception as e:
  188. self.log.error(
  189. f"Exception reading message from websocket: {e}"
  190. ) # , exc_info=True)
  191. if message is None:
  192. if not self.disconnected:
  193. self.log.warning(f"Lost connection to Gateway: {self.kernel_id}")
  194. break
  195. callback(
  196. message
  197. ) # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open)
  198. else: # ws cancelled - stop reading
  199. break
  200. # NOTE(esevan): if websocket is not disconnected by client, try to reconnect.
  201. if not self.disconnected and self.retry < GatewayClient.instance().gateway_retry_max:
  202. jitter = random.randint(10, 100) * 0.01 # noqa: S311
  203. retry_interval = (
  204. min(
  205. GatewayClient.instance().gateway_retry_interval * (2**self.retry),
  206. GatewayClient.instance().gateway_retry_interval_max,
  207. )
  208. + jitter
  209. )
  210. self.retry += 1
  211. self.log.info(
  212. "Attempting to re-establish the connection to Gateway in %s secs (%s/%s): %s",
  213. retry_interval,
  214. self.retry,
  215. GatewayClient.instance().gateway_retry_max,
  216. self.kernel_id,
  217. )
  218. await asyncio.sleep(retry_interval)
  219. loop = IOLoop.current()
  220. loop.spawn_callback(self._connect, self.kernel_id, callback)
  221. def on_open(self, kernel_id, message_callback, **kwargs):
  222. """Web socket connection open against gateway server."""
  223. loop = IOLoop.current()
  224. loop.spawn_callback(self._connect, kernel_id, message_callback)
  225. def on_message(self, message):
  226. """Send message to gateway server."""
  227. if self.ws is None:
  228. loop = IOLoop.current()
  229. loop.add_future(self.ws_future, lambda future: self._write_message(message))
  230. else:
  231. self._write_message(message)
  232. def _write_message(self, message):
  233. """Send message to gateway server."""
  234. try:
  235. if not self.disconnected and self.ws is not None:
  236. self.ws.write_message(message)
  237. except Exception as e:
  238. self.log.error(f"Exception writing message to websocket: {e}") # , exc_info=True)
  239. def on_close(self):
  240. """Web socket closed event."""
  241. self._disconnect()
  242. class GatewayResourceHandler(APIHandler):
  243. """Retrieves resources for specific kernelspec definitions from kernel/enterprise gateway."""
  244. @web.authenticated
  245. async def get(self, kernel_name, path, include_body=True):
  246. """Get a gateway resource by name and path."""
  247. mimetype: Optional[str] = None
  248. ksm = self.kernel_spec_manager
  249. kernel_spec_res = await ksm.get_kernel_spec_resource( # type:ignore[attr-defined]
  250. kernel_name, path
  251. )
  252. if kernel_spec_res is None:
  253. self.log.warning(
  254. f"Kernelspec resource '{path}' for '{kernel_name}' not found. Gateway may not support"
  255. " resource serving."
  256. )
  257. else:
  258. mimetype = mimetypes.guess_type(path)[0] or "text/plain"
  259. self.finish(kernel_spec_res, set_content_type=mimetype)
  260. from ..services.kernels.handlers import _kernel_id_regex
  261. from ..services.kernelspecs.handlers import kernel_name_regex
  262. default_handlers = [
  263. (r"/api/kernels/%s/channels" % _kernel_id_regex, WebSocketChannelsHandler),
  264. (r"/kernelspecs/%s/(?P<path>.*)" % kernel_name_regex, GatewayResourceHandler),
  265. ]