connections.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. """Gateway connection classes."""
  2. # Copyright (c) Jupyter Development Team.
  3. # Distributed under the terms of the Modified BSD License.
  4. from __future__ import annotations
  5. import asyncio
  6. import logging
  7. import random
  8. from typing import Any, cast
  9. import tornado.websocket as tornado_websocket
  10. from tornado.concurrent import Future
  11. from tornado.escape import json_decode, url_escape, utf8
  12. from tornado.httpclient import HTTPRequest
  13. from tornado.ioloop import IOLoop
  14. from traitlets import Bool, Instance, Int, Unicode
  15. from ..services.kernels.connection.base import BaseKernelWebsocketConnection
  16. from ..utils import url_path_join
  17. from .gateway_client import GatewayClient
  18. class GatewayWebSocketConnection(BaseKernelWebsocketConnection):
  19. """Web socket connection that proxies to a kernel/enterprise gateway."""
  20. ws = Instance(klass=tornado_websocket.WebSocketClientConnection, allow_none=True)
  21. ws_future = Instance(klass=Future, allow_none=True)
  22. disconnected = Bool(False)
  23. retry = Int(0)
  24. # When opening ws connection to gateway, server already negotiated subprotocol with notebook client.
  25. # Same protocol must be used for client and gateway, so legacy ws subprotocol for client is enforced here.
  26. kernel_ws_protocol = Unicode("", allow_none=True, config=True)
  27. async def connect(self):
  28. """Connect to the socket."""
  29. # websocket is initialized before connection
  30. self.ws = None
  31. ws_url = url_path_join(
  32. GatewayClient.instance().ws_url or "",
  33. GatewayClient.instance().kernels_endpoint,
  34. url_escape(self.kernel_id),
  35. "channels",
  36. )
  37. if self.session_id:
  38. ws_url += f"?session_id={url_escape(self.session_id)}"
  39. self.log.info(f"Connecting to {ws_url}")
  40. kwargs: dict[str, Any] = {}
  41. kwargs = GatewayClient.instance().load_connection_args(**kwargs)
  42. request = HTTPRequest(ws_url, **kwargs)
  43. self.ws_future = cast("Future[Any]", tornado_websocket.websocket_connect(request))
  44. self.ws_future.add_done_callback(self._connection_done)
  45. loop = IOLoop.current()
  46. loop.add_future(self.ws_future, lambda future: self._read_messages())
  47. def _connection_done(self, fut):
  48. """Handle a finished connection."""
  49. if (
  50. not self.disconnected and fut.exception() is None
  51. ): # prevent concurrent.futures._base.CancelledError
  52. self.ws = fut.result()
  53. self.retry = 0
  54. self.log.debug(f"Connection is ready: ws: {self.ws}")
  55. else:
  56. self.log.warning(
  57. "Websocket connection has been closed via client disconnect or due to error. "
  58. f"Kernel with ID '{self.kernel_id}' may not be terminated on GatewayClient: {GatewayClient.instance().url}"
  59. )
  60. def disconnect(self):
  61. """Handle a disconnect."""
  62. self.disconnected = True
  63. if self.ws is not None:
  64. # Close connection
  65. self.ws.close()
  66. elif self.ws_future and not self.ws_future.done():
  67. # Cancel pending connection. Since future.cancel() is a noop on tornado, we'll track cancellation locally
  68. self.ws_future.cancel()
  69. self.log.debug(f"_disconnect: future cancelled, disconnected: {self.disconnected}")
  70. async def _read_messages(self):
  71. """Read messages from gateway server."""
  72. while self.ws is not None:
  73. message = None
  74. if not self.disconnected:
  75. try:
  76. message = await self.ws.read_message()
  77. except Exception as e:
  78. self.log.error(
  79. f"Exception reading message from websocket: {e}"
  80. ) # , exc_info=True)
  81. if message is None:
  82. if not self.disconnected:
  83. self.log.warning(f"Lost connection to Gateway: {self.kernel_id}")
  84. break
  85. if isinstance(message, bytes):
  86. message = message.decode("utf8")
  87. self.handle_outgoing_message(
  88. message
  89. ) # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open)
  90. else: # ws cancelled - stop reading
  91. break
  92. # NOTE(esevan): if websocket is not disconnected by client, try to reconnect.
  93. if not self.disconnected and self.retry < GatewayClient.instance().gateway_retry_max:
  94. jitter = random.randint(10, 100) * 0.01 # noqa: S311
  95. retry_interval = (
  96. min(
  97. GatewayClient.instance().gateway_retry_interval * (2**self.retry),
  98. GatewayClient.instance().gateway_retry_interval_max,
  99. )
  100. + jitter
  101. )
  102. self.retry += 1
  103. self.log.info(
  104. "Attempting to re-establish the connection to Gateway in %s secs (%s/%s): %s",
  105. retry_interval,
  106. self.retry,
  107. GatewayClient.instance().gateway_retry_max,
  108. self.kernel_id,
  109. )
  110. await asyncio.sleep(retry_interval)
  111. loop = IOLoop.current()
  112. loop.spawn_callback(self.connect)
  113. def handle_outgoing_message(self, incoming_msg: str, *args: Any) -> None:
  114. """Send message to the notebook client."""
  115. try:
  116. self.websocket_handler.write_message(incoming_msg)
  117. except tornado_websocket.WebSocketClosedError:
  118. if self.log.isEnabledFor(logging.DEBUG):
  119. msg_summary = GatewayWebSocketConnection._get_message_summary(
  120. json_decode(utf8(incoming_msg))
  121. )
  122. self.log.debug(
  123. f"Notebook client closed websocket connection - message dropped: {msg_summary}"
  124. )
  125. def handle_incoming_message(self, message: str) -> None:
  126. """Send message to gateway server."""
  127. if self.ws is None and self.ws_future is not None:
  128. loop = IOLoop.current()
  129. loop.add_future(self.ws_future, lambda future: self.handle_incoming_message(message))
  130. else:
  131. self._write_message(message)
  132. def _write_message(self, message):
  133. """Send message to gateway server."""
  134. try:
  135. if not self.disconnected and self.ws is not None:
  136. self.ws.write_message(message)
  137. except Exception as e:
  138. self.log.error(f"Exception writing message to websocket: {e}") # , exc_info=True)
  139. @staticmethod
  140. def _get_message_summary(message):
  141. """Get a summary of a message."""
  142. summary = []
  143. message_type = message["msg_type"]
  144. summary.append(f"type: {message_type}")
  145. if message_type == "status":
  146. summary.append(", state: {}".format(message["content"]["execution_state"]))
  147. elif message_type == "error":
  148. summary.append(
  149. ", {}:{}:{}".format(
  150. message["content"]["ename"],
  151. message["content"]["evalue"],
  152. message["content"]["traceback"],
  153. )
  154. )
  155. else:
  156. summary.append(", ...") # don't display potentially sensitive data
  157. return "".join(summary)