| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181 |
- """Kernel connection helpers."""
- import json
- import struct
- from typing import Any
- from jupyter_client.session import Session
- from tornado.websocket import WebSocketHandler
- from traitlets import Float, Instance, Unicode, default
- from traitlets.config import LoggingConfigurable
- try:
- from jupyter_client.jsonutil import json_default
- except ImportError:
- from jupyter_client.jsonutil import date_default as json_default
- from jupyter_client.jsonutil import extract_dates
- from jupyter_server.transutils import _i18n
- from .abc import KernelWebsocketConnectionABC
def serialize_binary_message(msg):
    """Serialize a message as a binary blob.

    Layout (every integer is an unsigned 32-bit big-endian int):

    - 4 bytes: number of message parts (``nbufs``), i.e. the JSON-encoded
      message followed by each entry of ``msg["buffers"]``
    - 4 * nbufs bytes: the byte offset at which each part starts
    - the parts themselves, concatenated

    Offsets are measured from the start of the blob, header included.

    Parameters
    ----------
    msg : dict
        The message. Its ``"buffers"`` entry holds the binary buffers; every
        other key is JSON-encoded into the first part. ``msg`` itself and its
        buffer list are not modified.

    Returns
    -------
    bytes
        The message serialized to bytes.
    """
    # Shallow-copy so popping "buffers" does not mutate the caller's dict,
    # and copy the buffer list so the inserts below do not touch theirs.
    msg = msg.copy()
    buffers = list(msg.pop("buffers"))
    bmsg = json.dumps(msg, default=json_default).encode("utf8")
    buffers.insert(0, bmsg)
    nbufs = len(buffers)
    # The first part starts right after the header: the count plus one
    # 4-byte offset per part.
    offsets = [4 * (nbufs + 1)]
    for buf in buffers[:-1]:
        # Use nbytes rather than len(): for a memoryview with a multi-byte
        # format, len() counts items, whereas b"".join below copies every
        # byte of the buffer, so len() would yield wrong offsets.
        offsets.append(offsets[-1] + memoryview(buf).nbytes)
    offsets_buf = struct.pack(f"!{nbufs + 1}I", nbufs, *offsets)
    buffers.insert(0, offsets_buf)
    return b"".join(buffers)
def deserialize_binary_message(bmsg):
    """Deserialize a message from a binary blob.

    Layout (every integer is an unsigned 32-bit big-endian int):

    - 4 bytes: number of message parts (``nbufs``)
    - 4 * nbufs bytes: the byte offset at which each part starts
    - the parts themselves; the first is the UTF-8 encoded JSON message and
      the rest are its binary buffers

    Offsets are measured from the start of the blob, header included.

    Parameters
    ----------
    bmsg : bytes
        The raw binary frame.

    Returns
    -------
    dict
        The message dictionary. Its ``header`` and ``parent_header`` are
        passed through ``extract_dates``, and ``buffers`` holds the parts
        that follow the JSON part.

    Raises
    ------
    ValueError
        If the header is truncated, declares zero parts, or declares more
        offsets than ``bmsg`` can hold; also if the first part is not valid
        UTF-8 JSON (``UnicodeDecodeError`` and ``json.JSONDecodeError`` are
        both ``ValueError`` subclasses).
    """
    if len(bmsg) < 4:
        raise ValueError("Binary message is too short to contain a part count")
    # Unsigned, matching the "!I" format used when serializing.
    (nbufs,) = struct.unpack("!I", bmsg[:4])
    if nbufs == 0:
        raise ValueError("Binary message declares zero parts")
    header_end = 4 * (nbufs + 1)
    # The count is supplied by the sender: check it against the real size
    # before using it to build a struct format, otherwise a forged count near
    # 2**32 would allocate a multi-gigabyte format string before failing.
    if header_end > len(bmsg):
        raise ValueError(
            f"Binary message declares {nbufs} parts but is only {len(bmsg)} bytes long"
        )
    offsets = list(struct.unpack(f"!{nbufs}I", bmsg[4:header_end]))
    # A trailing None makes the last slice run to the end of the blob.
    offsets.append(None)
    bufs = [bmsg[start:stop] for start, stop in zip(offsets[:-1], offsets[1:])]
    msg = json.loads(bufs[0].decode("utf8"))
    msg["header"] = extract_dates(msg["header"])
    msg["parent_header"] = extract_dates(msg["parent_header"])
    msg["buffers"] = bufs[1:]
    return msg
def serialize_msg_to_ws_v1(msg_or_list, channel, pack=None):
    """Serialize a message using the v1 kernel websocket protocol.

    Layout (every integer is an unsigned 64-bit little-endian int):

    - 8 bytes: ``offset_number``, the number of offsets that follow
    - 8 * offset_number bytes: offsets from the start of the blob; the first
      marks the start of the channel name, the next ones mark the start of
      each message part, and the last marks the end of the final part
    - the UTF-8 encoded channel name
    - the message parts, concatenated

    Parameters
    ----------
    msg_or_list : dict or sequence of bytes-like
        When ``pack`` is given, a message dict whose ``header``,
        ``parent_header``, ``metadata`` and ``content`` entries are packed,
        in that order, into the message parts. Otherwise, the message parts
        themselves, used as-is.
    channel : str
        Name of the channel the message belongs to.
    pack : callable, optional
        Function turning one message field into bytes.

    Returns
    -------
    bytes
        The serialized message.
    """
    if pack:
        msg_list = [
            pack(msg_or_list["header"]),
            pack(msg_or_list["parent_header"]),
            pack(msg_or_list["metadata"]),
            pack(msg_or_list["content"]),
        ]
    else:
        msg_list = msg_or_list
    channel_bytes = channel.encode("utf-8")
    # One offset for the channel start, one per part start, and a final one
    # for the end of the last part.
    offset_number = 1 + len(msg_list) + 1
    # The header holds the count itself plus offset_number offsets.
    offsets: list[int] = [8 * (1 + offset_number)]
    offsets.append(offsets[-1] + len(channel_bytes))
    for part in msg_list:
        # Use nbytes rather than len(): for a memoryview with a multi-byte
        # format, len() counts items, whereas b"".join below copies every
        # byte of the buffer, so len() would yield wrong offsets.
        offsets.append(offsets[-1] + memoryview(part).nbytes)
    header = b"".join(
        value.to_bytes(8, byteorder="little") for value in (offset_number, *offsets)
    )
    return b"".join([header, channel_bytes, *msg_list])
def deserialize_msg_from_ws_v1(ws_msg):
    """Deserialize a message using the v1 kernel websocket protocol.

    Layout (every integer is an unsigned 64-bit little-endian int):

    - 8 bytes: ``offset_number``, the number of offsets that follow
    - 8 * offset_number bytes: offsets from the start of the frame; the first
      two delimit the UTF-8 channel name, and each following pair of
      consecutive offsets delimits one message part

    Parameters
    ----------
    ws_msg : bytes
        The raw binary websocket frame.

    Returns
    -------
    tuple
        ``(channel, msg_list)``: the decoded channel name and the list of
        ``offset_number - 2`` message parts, each a slice of ``ws_msg``.

    Raises
    ------
    ValueError
        If the frame is shorter than the 8-byte count, declares fewer than
        two offsets, or declares more offsets than it can hold; also if the
        channel name is not valid UTF-8 (``UnicodeDecodeError`` is a
        ``ValueError`` subclass).
    """
    if len(ws_msg) < 8:
        raise ValueError("v1 websocket message is too short to contain an offset count")
    offset_number = int.from_bytes(ws_msg[:8], "little")
    if offset_number < 2:
        raise ValueError(
            f"v1 websocket message declares {offset_number} offsets; at least 2 are required"
        )
    # The count is supplied by the client: make sure the offset table fits in
    # the frame before looping over it, otherwise a forged count (up to
    # 2**64 - 1) would keep the comprehension below running practically forever.
    if 8 * (offset_number + 1) > len(ws_msg):
        raise ValueError(
            f"v1 websocket message declares {offset_number} offsets "
            f"but is only {len(ws_msg)} bytes long"
        )
    offsets = [
        int.from_bytes(ws_msg[8 * (i + 1) : 8 * (i + 2)], "little") for i in range(offset_number)
    ]
    channel = ws_msg[offsets[0] : offsets[1]].decode("utf-8")
    msg_list = [ws_msg[start:stop] for start, stop in zip(offsets[1:-1], offsets[2:])]
    return channel, msg_list
class BaseKernelWebsocketConnection(LoggingConfigurable):
    """Configurable base class linking a kernel WebSocket to ZMQ sockets.

    The four connection hooks (``connect``, ``disconnect``,
    ``handle_incoming_message`` and ``handle_outgoing_message``) all raise
    ``NotImplementedError`` here and are meant to be overridden.
    """

    kernel_ws_protocol = Unicode(
        None,
        allow_none=True,
        config=True,
        help=_i18n(
            "Preferred kernel message protocol over websocket to use (default: None). "
            "If an empty string is passed, select the legacy protocol. If None, "
            "the selected protocol will depend on what the front-end supports "
            "(usually the most recent protocol supported by the back-end and the "
            "front-end)."
        ),
    )

    @property
    def kernel_manager(self):
        """The kernel manager, i.e. this object's traitlets ``parent``."""
        return self.parent

    @property
    def multi_kernel_manager(self):
        """The multi kernel manager, i.e. the ``parent`` of ``kernel_manager``."""
        owner = self.kernel_manager
        return owner.parent

    @property
    def kernel_id(self):
        """The kernel id, as exposed by ``kernel_manager``."""
        owner = self.kernel_manager
        return owner.kernel_id

    @property
    def session_id(self):
        """The ``session`` attribute of the ``session`` trait."""
        current = self.session
        return current.session

    kernel_info_timeout = Float()

    @default("kernel_info_timeout")
    def _default_kernel_info_timeout(self):
        # Fall back to the timeout configured on the multi kernel manager.
        manager = self.multi_kernel_manager
        return manager.kernel_info_timeout

    session = Instance(Session, config=True)

    @default("session")
    def _default_session(self):
        # Build a new Session from this object's own config.
        cfg = self.config
        return Session(config=cfg)

    websocket_handler = Instance(WebSocketHandler)

    async def connect(self):
        """Handle a connect; subclasses must override this."""
        raise NotImplementedError()

    async def disconnect(self):
        """Handle a disconnect; subclasses must override this."""
        raise NotImplementedError()

    def handle_incoming_message(self, incoming_msg: str) -> None:
        """Handle a message arriving from the websocket; subclasses must override this."""
        raise NotImplementedError()

    def handle_outgoing_message(self, stream: str, outgoing_msg: list[Any]) -> None:
        """Handle a message bound for the websocket; subclasses must override this."""
        raise NotImplementedError()


# Register the base class as a virtual subclass of the connection ABC.
KernelWebsocketConnectionABC.register(BaseKernelWebsocketConnection)
|