| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111 |
- """Session object for building, serializing, sending, and receiving messages.
- The Session object supports serialization, HMAC signatures,
- and metadata on messages.
- Also defined here are utilities for working with Sessions:
- * A SessionFactory to be used as a base class for configurables that work with
- Sessions.
- * A Message object for convenience that allows attribute-access to the msg dict.
- """
- # Copyright (c) Jupyter Development Team.
- # Distributed under the terms of the Modified BSD License.
- from __future__ import annotations
- import functools
- import hashlib
- import hmac
- import json
- import logging
- import os
- import pickle
- import pprint
- import random
- import typing as t
- import warnings
- from binascii import b2a_hex
- from datetime import datetime, timezone
- from hmac import compare_digest
- # We are using compare_digest to limit the surface of timing attacks
- import zmq.asyncio
- from tornado.ioloop import IOLoop
- from traitlets import (
- Any,
- Bool,
- Callable,
- CBytes,
- CUnicode,
- Dict,
- DottedObjectName,
- Instance,
- Integer,
- Set,
- TraitError,
- Unicode,
- observe,
- )
- from traitlets.config.configurable import Configurable, LoggingConfigurable
- from traitlets.log import get_logger
- from traitlets.utils.importstring import import_item
- from zmq.eventloop.zmqstream import ZMQStream
- from ._version import protocol_version
- from .adapter import adapt
- from .jsonutil import extract_dates, json_clean, json_default, squash_dates
- PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
- utc = timezone.utc
- # -----------------------------------------------------------------------------
- # utility functions
- # -----------------------------------------------------------------------------
- def squash_unicode(obj: t.Any) -> t.Any:
- """coerce unicode back to bytestrings."""
- if isinstance(obj, dict):
- for key in list(obj.keys()):
- obj[key] = squash_unicode(obj[key])
- if isinstance(key, str):
- obj[squash_unicode(key)] = obj.pop(key)
- elif isinstance(obj, list):
- for i, v in enumerate(obj):
- obj[i] = squash_unicode(v)
- elif isinstance(obj, str):
- obj = obj.encode("utf8")
- return obj
- # -----------------------------------------------------------------------------
- # globals and defaults
- # -----------------------------------------------------------------------------
- # default values for the thresholds:
- MAX_ITEMS = 64
- MAX_BYTES = 1024
- # ISO8601-ify datetime objects
- # allow unicode
- # disallow nan, because it's not actually valid JSON
- def json_packer(obj: t.Any) -> bytes:
- """Convert a json object to a bytes."""
- try:
- return json.dumps(
- obj,
- default=json_default,
- ensure_ascii=False,
- allow_nan=False,
- ).encode("utf8", errors="surrogateescape")
- except (TypeError, ValueError) as e:
- # Fallback to trying to clean the json before serializing
- packed = json.dumps(
- json_clean(obj),
- default=json_default,
- ensure_ascii=False,
- allow_nan=False,
- ).encode("utf8", errors="surrogateescape")
- warnings.warn(
- f"Message serialization failed with:\n{e}\n"
- "Supporting this message is deprecated in jupyter-client 7, please make "
- "sure your message is JSON-compliant",
- stacklevel=2,
- )
- return packed
- def json_unpacker(s: str | bytes) -> t.Any:
- """Convert a json bytes or string to an object."""
- if isinstance(s, bytes):
- s = s.decode("utf8", "replace")
- return json.loads(s)
- try:
- import orjson
- except ModuleNotFoundError:
- has_orjson = False
- orjson_packer, orjson_unpacker = json_packer, json_unpacker
- else:
- has_orjson = True
- def orjson_packer(
- obj: t.Any, *, option: int | None = orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z
- ) -> bytes:
- """Convert a json object to a bytes using orjson with fallback to json_packer."""
- try:
- return orjson.dumps(obj, default=json_default, option=option)
- except Exception:
- return json_packer(obj)
- def orjson_unpacker(s: str | bytes) -> t.Any:
- """Convert a json bytes or string to an object using orjson with fallback to json_unpacker."""
- try:
- return orjson.loads(s)
- except Exception:
- return json_unpacker(s)
- try:
- import msgpack
- except ModuleNotFoundError:
- has_msgpack = False
- else:
- has_msgpack = True
- msgpack_packer = functools.partial(msgpack.packb, default=json_default)
- msgpack_unpacker = msgpack.unpackb
- def pickle_packer(o: t.Any) -> bytes:
- """Pack an object using the pickle module."""
- return pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
- pickle_unpacker = pickle.loads
- DELIM = b"<IDS|MSG>"
- # singleton dummy tracker, which will always report as done
- DONE = zmq.MessageTracker()
- # -----------------------------------------------------------------------------
- # Mixin tools for apps that use Sessions
- # -----------------------------------------------------------------------------
- def new_id() -> str:
- """Generate a new random id.
- Avoids problematic runtime import in stdlib uuid on Python 2.
- Returns
- -------
- id string (16 random bytes as hex-encoded text, chunks separated by '-')
- """
- buf = os.urandom(16)
- return "-".join(b2a_hex(x).decode("ascii") for x in (buf[:4], buf[4:]))
- def new_id_bytes() -> bytes:
- """Return new_id as ascii bytes"""
- return new_id().encode("ascii")
- session_aliases = {
- "ident": "Session.session",
- "user": "Session.username",
- "keyfile": "Session.keyfile",
- }
- session_flags = {
- "secure": (
- {"Session": {"key": new_id_bytes(), "keyfile": ""}},
- """Use HMAC digests for authentication of messages.
- Setting this flag will generate a new UUID to use as the HMAC key.
- """,
- ),
- "no-secure": (
- {"Session": {"key": b"", "keyfile": ""}},
- """Don't authenticate messages.""",
- ),
- }
- def default_secure(cfg: t.Any) -> None: # pragma: no cover
- """Set the default behavior for a config environment to be secure.
- If Session.key/keyfile have not been set, set Session.key to
- a new random UUID.
- """
- warnings.warn("default_secure is deprecated", DeprecationWarning, stacklevel=2)
- if "Session" in cfg and ("key" in cfg.Session or "keyfile" in cfg.Session):
- return
- # key/keyfile not specified, generate new UUID:
- cfg.Session.key = new_id_bytes()
- def utcnow() -> datetime:
- """Return timezone-aware UTC timestamp"""
- return datetime.now(utc)
- # -----------------------------------------------------------------------------
- # Classes
- # -----------------------------------------------------------------------------
- class SessionFactory(LoggingConfigurable):
- """The Base class for configurables that have a Session, Context, logger,
- and IOLoop.
- """
- logname = Unicode("")
- @observe("logname")
- def _logname_changed(self, change: t.Any) -> None:
- self.log = logging.getLogger(change["new"])
- # not configurable:
- context = Instance("zmq.Context")
- def _context_default(self) -> zmq.Context:
- return zmq.Context()
- session = Instance("jupyter_client.session.Session", allow_none=True)
- loop = Instance("tornado.ioloop.IOLoop")
- def _loop_default(self) -> IOLoop:
- return IOLoop.current()
- def __init__(self, **kwargs: t.Any) -> None:
- """Initialize a session factory."""
- super().__init__(**kwargs)
- if self.session is None:
- # construct the session
- self.session = Session(**kwargs)
- class Message:
- """A simple message object that maps dict keys to attributes.
- A Message can be created from a dict and a dict from a Message instance
- simply by calling dict(msg_obj)."""
- def __init__(self, msg_dict: dict[str, t.Any]) -> None:
- """Initialize a message."""
- dct = self.__dict__
- for k, v in dict(msg_dict).items():
- if isinstance(v, dict):
- v = Message(v) # noqa
- dct[k] = v
- # Having this iterator lets dict(msg_obj) work out of the box.
- def __iter__(self) -> t.ItemsView[str, t.Any]:
- return iter(self.__dict__.items()) # type:ignore[return-value]
- def __repr__(self) -> str:
- return repr(self.__dict__)
- def __str__(self) -> str:
- return pprint.pformat(self.__dict__)
- def __contains__(self, k: object) -> bool:
- return k in self.__dict__
- def __getitem__(self, k: str) -> t.Any:
- return self.__dict__[k]
- def msg_header(
- msg_id: str, msg_type: str, username: str, session: Session | str
- ) -> dict[str, t.Any]:
- """Create a new message header"""
- date = utcnow()
- version = protocol_version
- return locals()
- def extract_header(msg_or_header: dict[str, t.Any]) -> dict[str, t.Any]:
- """Given a message or header, return the header."""
- if not msg_or_header:
- return {}
- try:
- # See if msg_or_header is the entire message.
- h = msg_or_header["header"]
- except KeyError:
- try:
- # See if msg_or_header is just the header
- h = msg_or_header["msg_id"]
- except KeyError:
- raise
- else:
- h = msg_or_header
- if not isinstance(h, dict):
- h = dict(h)
- return h
- class Session(Configurable):
- """Object for handling serialization and sending of messages.
- The Session object handles building messages and sending them
- with ZMQ sockets or ZMQStream objects. Objects can communicate with each
- other over the network via Session objects, and only need to work with the
- dict-based IPython message spec. The Session will handle
- serialization/deserialization, security, and metadata.
- Sessions support configurable serialization via packer/unpacker traits,
- and signing with HMAC digests via the key/keyfile traits.
- Parameters
- ----------
- debug : bool
- whether to trigger extra debugging statements
- packer/unpacker : str : 'orjson', 'json', 'pickle', 'msgpack' or import_string
- importstrings for methods to serialize message parts. If just
- 'json' or 'pickle', predefined JSON and pickle packers will be used.
- Otherwise, the entire importstring must be used.
- The functions must accept at least valid JSON input, and output *bytes*.
- For example, to use msgpack:
- packer = 'msgpack.packb', unpacker='msgpack.unpackb'
- pack/unpack : callables
- You can also set the pack/unpack callables for serialization directly.
- session : bytes
- the ID of this Session object. The default is to generate a new UUID.
- username : unicode
- username added to message headers. The default is to ask the OS.
- key : bytes
- The key used to initialize an HMAC signature. If unset, messages
- will not be signed or checked.
- keyfile : filepath
- The file containing a key. If this is set, `key` will be initialized
- to the contents of the file.
- """
- debug = Bool(False, config=True, help="""Debug output in the Session""")
- check_pid = Bool(
- True,
- config=True,
- help="""Whether to check PID to protect against calls after fork.
- This check can be disabled if fork-safety is handled elsewhere.
- """,
- )
- # serialization traits:
- packer = DottedObjectName(
- "orjson" if has_orjson else "json",
- config=True,
- help="""The name of the packer for serializing messages.
- Should be one of 'json', 'pickle', or an import name
- for a custom callable serializer.""",
- )
- unpacker = DottedObjectName(
- "orjson" if has_orjson else "json",
- config=True,
- help="""The name of the unpacker for unserializing messages.
- Only used with custom functions for `packer`.""",
- )
- pack = Callable(orjson_packer if has_orjson else json_packer) # the actual packer function
- unpack = Callable(
- orjson_unpacker if has_orjson else json_unpacker
- ) # the actual unpacker function
- @observe("packer", "unpacker")
- def _packer_unpacker_changed(self, change: t.Any) -> None:
- new = change["new"].lower()
- if new == "orjson" and has_orjson:
- self.pack, self.unpack = orjson_packer, orjson_unpacker
- elif new == "json" or new == "orjson":
- self.pack, self.unpack = json_packer, json_unpacker
- elif new == "pickle":
- self.pack, self.unpack = pickle_packer, pickle_unpacker
- elif new == "msgpack" and has_msgpack:
- self.pack, self.unpack = msgpack_packer, msgpack_unpacker
- else:
- obj = import_item(str(change["new"]))
- name = "pack" if change["name"] == "packer" else "unpack"
- self.set_trait(name, obj)
- return
- self.packer = self.unpacker = change["new"]
- session = CUnicode("", config=True, help="""The UUID identifying this session.""")
- def _session_default(self) -> str:
- u = new_id()
- self.bsession = u.encode("ascii")
- return u
- @observe("session")
- def _session_changed(self, change: t.Any) -> None:
- self.bsession = self.session.encode("ascii")
- # bsession is the session as bytes
- bsession = CBytes(b"")
- username = Unicode(
- os.environ.get("USER", "username"),
- help="""Username for the Session. Default is your system username.""",
- config=True,
- )
- metadata = Dict(
- {},
- config=True,
- help="Metadata dictionary, which serves as the default top-level metadata dict for each message.",
- )
- # if 0, no adapting to do.
- adapt_version = Integer(0)
- # message signature related traits:
- key = CBytes(config=True, help="""execution key, for signing messages.""")
- def _key_default(self) -> bytes:
- return new_id_bytes()
- @observe("key")
- def _key_changed(self, change: t.Any) -> None:
- self._new_auth()
- signature_scheme = Unicode(
- "hmac-sha256",
- config=True,
- help="""The digest scheme used to construct the message signatures.
- Must have the form 'hmac-HASH'.""",
- )
- @observe("signature_scheme")
- def _signature_scheme_changed(self, change: t.Any) -> None:
- new = change["new"]
- if not new.startswith("hmac-"):
- raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
- hash_name = new.split("-", 1)[1]
- try:
- self.digest_mod = getattr(hashlib, hash_name)
- except AttributeError as e:
- raise TraitError("hashlib has no such attribute: %s" % hash_name) from e
- self._new_auth()
- digest_mod = Any()
- def _digest_mod_default(self) -> t.Callable:
- return hashlib.sha256
- auth = Instance(hmac.HMAC, allow_none=True)
- def _new_auth(self) -> None:
- if self.key:
- self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
- else:
- self.auth = None
- digest_history = Set()
- digest_history_size = Integer(
- 2**16,
- config=True,
- help="""The maximum number of digests to remember.
- The digest history will be culled when it exceeds this value.
- """,
- )
- keyfile = Unicode("", config=True, help="""path to file containing execution key.""")
- @observe("keyfile")
- def _keyfile_changed(self, change: t.Any) -> None:
- with open(change["new"], "rb") as f:
- self.key = f.read().strip()
- # for protecting against sends from forks
- pid = Integer()
- # thresholds:
- copy_threshold = Integer(
- 2**16,
- config=True,
- help="Threshold (in bytes) beyond which a buffer should be sent without copying.",
- )
- buffer_threshold = Integer(
- MAX_BYTES,
- config=True,
- help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.",
- )
- item_threshold = Integer(
- MAX_ITEMS,
- config=True,
- help="""The maximum number of items for a container to be introspected for custom serialization.
- Containers larger than this are pickled outright.
- """,
- )
- def __init__(self, **kwargs: t.Any) -> None:
- """create a Session object
- Parameters
- ----------
- debug : bool
- whether to trigger extra debugging statements
- packer/unpacker : str : 'orjson', 'json', 'pickle', 'msgpack' or import_string
- importstrings for methods to serialize message parts. If just
- 'json' or 'pickle', predefined JSON and pickle packers will be used.
- Otherwise, the entire importstring must be used.
- The functions must accept at least valid JSON input, and output
- *bytes*.
- For example, to use msgpack:
- packer = 'msgpack.packb', unpacker='msgpack.unpackb'
- pack/unpack : callables
- You can also set the pack/unpack callables for serialization
- directly.
- session : unicode (must be ascii)
- the ID of this Session object. The default is to generate a new
- UUID.
- bsession : bytes
- The session as bytes
- username : unicode
- username added to message headers. The default is to ask the OS.
- key : bytes
- The key used to initialize an HMAC signature. If unset, messages
- will not be signed or checked.
- signature_scheme : str
- The message digest scheme. Currently must be of the form 'hmac-HASH',
- where 'HASH' is a hashing function available in Python's hashlib.
- The default is 'hmac-sha256'.
- This is ignored if 'key' is empty.
- keyfile : filepath
- The file containing a key. If this is set, `key` will be
- initialized to the contents of the file.
- """
- super().__init__(**kwargs)
- self._check_packers()
- self.none = self.pack({})
- # ensure self._session_default() if necessary, so bsession is defined:
- self.session # noqa
- self.pid = os.getpid()
- self._new_auth()
- if not self.key:
- get_logger().warning(
- "Message signing is disabled. This is insecure and not recommended!"
- )
- def clone(self) -> Session:
- """Create a copy of this Session
- Useful when connecting multiple times to a given kernel.
- This prevents a shared digest_history warning about duplicate digests
- due to multiple connections to IOPub in the same process.
- .. versionadded:: 5.1
- """
- # make a copy
- new_session = type(self)()
- for name in self.traits():
- setattr(new_session, name, getattr(self, name))
- # fork digest_history
- new_session.digest_history = set()
- new_session.digest_history.update(self.digest_history)
- return new_session
- message_count = 0
- @property
- def msg_id(self) -> str:
- message_number = self.message_count
- self.message_count += 1
- return f"{self.session}_{os.getpid()}_{message_number}"
- def _check_packers(self) -> None:
- """check packers for datetime support."""
- pack = self.pack
- unpack = self.unpack
- # check simple serialization
- msg_list = {"a": [1, "hi"]}
- try:
- packed = pack(msg_list)
- except Exception as e:
- msg = f"packer '{self.packer}' could not serialize a simple message: {e}"
- raise ValueError(msg) from e
- # ensure packed message is bytes
- if not isinstance(packed, bytes):
- raise ValueError("message packed to %r, but bytes are required" % type(packed))
- # check that unpack is pack's inverse
- try:
- unpacked = unpack(packed)
- assert unpacked == msg_list
- except Exception as e:
- msg = f"unpacker {self.unpacker!r} could not handle output from packer {self.packer!r}: {e}"
- raise ValueError(msg) from e
- # check datetime support
- msg_datetime = {"t": utcnow()}
- try:
- unpacked = unpack(pack(msg_datetime))
- if isinstance(unpacked["t"], datetime):
- msg = "Shouldn't deserialize to datetime"
- raise ValueError(msg)
- except Exception:
- self.pack = lambda o: pack(squash_dates(o))
- self.unpack = lambda s: unpack(s)
- def msg_header(self, msg_type: str) -> dict[str, t.Any]:
- """Create a header for a message type."""
- return msg_header(self.msg_id, msg_type, self.username, self.session)
- def msg(
- self,
- msg_type: str,
- content: dict | None = None,
- parent: dict[str, t.Any] | None = None,
- header: dict[str, t.Any] | None = None,
- metadata: dict[str, t.Any] | None = None,
- ) -> dict[str, t.Any]:
- """Return the nested message dict.
- This format is different from what is sent over the wire. The
- serialize/deserialize methods converts this nested message dict to the wire
- format, which is a list of message parts.
- """
- msg = {}
- header = self.msg_header(msg_type) if header is None else header
- msg["header"] = header
- msg["msg_id"] = header["msg_id"]
- msg["msg_type"] = header["msg_type"]
- msg["parent_header"] = {} if parent is None else extract_header(parent)
- msg["content"] = {} if content is None else content
- msg["metadata"] = self.metadata.copy()
- if metadata is not None:
- msg["metadata"].update(metadata)
- return msg
- def sign(self, msg_list: list) -> bytes:
- """Sign a message with HMAC digest. If no auth, return b''.
- Parameters
- ----------
- msg_list : list
- The [p_header,p_parent,p_content] part of the message list.
- """
- if self.auth is None:
- return b""
- h = self.auth.copy()
- for m in msg_list:
- h.update(m)
- return h.hexdigest().encode()
- def serialize(
- self,
- msg: dict[str, t.Any],
- ident: list[bytes] | bytes | None = None,
- ) -> list[bytes]:
- """Serialize the message components to bytes.
- This is roughly the inverse of deserialize. The serialize/deserialize
- methods work with full message lists, whereas pack/unpack work with
- the individual message parts in the message list.
- Parameters
- ----------
- msg : dict or Message
- The next message dict as returned by the self.msg method.
- Returns
- -------
- msg_list : list
- The list of bytes objects to be sent with the format::
- [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
- p_metadata, p_content, buffer1, buffer2, ...]
- In this list, the ``p_*`` entities are the packed or serialized
- versions, so if JSON is used, these are utf8 encoded JSON strings.
- """
- content = msg.get("content", {})
- if content is None:
- content = self.none
- elif isinstance(content, dict):
- content = self.pack(content)
- elif isinstance(content, bytes):
- # content is already packed, as in a relayed message
- pass
- elif isinstance(content, str):
- # should be bytes, but JSON often spits out unicode
- content = content.encode("utf8")
- else:
- raise TypeError("Content incorrect type: %s" % type(content))
- real_message = [
- self.pack(msg["header"]),
- self.pack(msg["parent_header"]),
- self.pack(msg["metadata"]),
- content,
- ]
- to_send = []
- if isinstance(ident, list):
- # accept list of idents
- to_send.extend(ident)
- elif ident is not None:
- to_send.append(ident)
- to_send.append(DELIM)
- signature = self.sign(real_message)
- to_send.append(signature)
- to_send.extend(real_message)
- return to_send
- def send(
- self,
- stream: zmq.sugar.socket.Socket | ZMQStream | None,
- msg_or_type: dict[str, t.Any] | str,
- content: dict[str, t.Any] | None = None,
- parent: dict[str, t.Any] | None = None,
- ident: bytes | list[bytes] | None = None,
- buffers: list[bytes | memoryview[bytes]] | None = None,
- track: bool = False,
- header: dict[str, t.Any] | None = None,
- metadata: dict[str, t.Any] | None = None,
- ) -> dict[str, t.Any] | None:
- """Build and send a message via stream or socket.
- The message format used by this function internally is as follows:
- [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
- buffer1,buffer2,...]
- The serialize/deserialize methods convert the nested message dict into this
- format.
- Parameters
- ----------
- stream : zmq.Socket or ZMQStream
- The socket-like object used to send the data.
- msg_or_type : str or Message/dict
- Normally, msg_or_type will be a msg_type unless a message is being
- sent more than once. If a header is supplied, this can be set to
- None and the msg_type will be pulled from the header.
- content : dict or None
- The content of the message (ignored if msg_or_type is a message).
- header : dict or None
- The header dict for the message (ignored if msg_to_type is a message).
- parent : Message or dict or None
- The parent or parent header describing the parent of this message
- (ignored if msg_or_type is a message).
- ident : bytes or list of bytes
- The zmq.IDENTITY routing path.
- metadata : dict or None
- The metadata describing the message
- buffers : list or None
- The already-serialized buffers to be appended to the message.
- track : bool
- Whether to track. Only for use with Sockets, because ZMQStream
- objects cannot track messages.
- Returns
- -------
- msg : dict
- The constructed message.
- """
- if not isinstance(stream, zmq.Socket):
- # ZMQStreams and dummy sockets do not support tracking.
- track = False
- if isinstance(stream, zmq.asyncio.Socket):
- assert stream is not None
- stream = zmq.Socket.shadow(stream.underlying)
- if isinstance(msg_or_type, Message | dict):
- # We got a Message or message dict, not a msg_type so don't
- # build a new Message.
- msg = msg_or_type
- buffers = buffers or msg.get("buffers", [])
- else:
- msg = self.msg(
- msg_or_type,
- content=content,
- parent=parent,
- header=header,
- metadata=metadata,
- )
- if self.check_pid and os.getpid() != self.pid:
- get_logger().warning("WARNING: attempted to send message from fork\n%s", msg)
- return None
- buffers = [] if buffers is None else buffers
- for idx, buf in enumerate(buffers):
- if isinstance(buf, memoryview):
- view = buf
- else:
- try:
- # check to see if buf supports the buffer protocol.
- view = memoryview(buf)
- except TypeError as e:
- emsg = "Buffer objects must support the buffer protocol."
- raise TypeError(emsg) from e
- if not view.contiguous:
- # zmq requires memoryviews to be contiguous
- raise ValueError("Buffer %i (%r) is not contiguous" % (idx, buf))
- if self.adapt_version:
- msg = adapt(msg, self.adapt_version)
- to_send = self.serialize(msg, ident)
- to_send.extend(buffers) # type: ignore[arg-type]
- longest = max([len(s) for s in to_send])
- copy = longest < self.copy_threshold
- if stream and buffers and track and not copy:
- # only really track when we are doing zero-copy buffers
- tracker = stream.send_multipart(to_send, copy=False, track=True)
- elif stream:
- # use dummy tracker, which will be done immediately
- tracker = DONE
- stream.send_multipart(to_send, copy=copy)
- else:
- tracker = DONE
- if self.debug:
- pprint.pprint(msg) # noqa
- pprint.pprint(to_send) # noqa
- pprint.pprint(buffers) # noqa
- msg["tracker"] = tracker
- return msg
- def send_raw(
- self,
- stream: zmq.sugar.socket.Socket,
- msg_list: list,
- flags: int = 0,
- copy: bool = True,
- ident: bytes | list[bytes] | None = None,
- ) -> None:
- """Send a raw message via ident path.
- This method is used to send a already serialized message.
- Parameters
- ----------
- stream : ZMQStream or Socket
- The ZMQ stream or socket to use for sending the message.
- msg_list : list
- The serialized list of messages to send. This only includes the
- [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
- the message.
- ident : ident or list
- A single ident or a list of idents to use in sending.
- """
- to_send = []
- if isinstance(ident, bytes):
- ident = [ident]
- if ident is not None:
- to_send.extend(ident)
- to_send.append(DELIM)
- # Don't include buffers in signature (per spec).
- to_send.append(self.sign(msg_list[0:4]))
- to_send.extend(msg_list)
- if isinstance(stream, zmq.asyncio.Socket):
- stream = zmq.Socket.shadow(stream.underlying)
- stream.send_multipart(to_send, flags, copy=copy)
- def recv(
- self,
- socket: zmq.sugar.socket.Socket,
- mode: int = zmq.NOBLOCK,
- content: bool = True,
- copy: bool = True,
- ) -> tuple[list[bytes] | None, dict[str, t.Any] | None]:
- """Receive and unpack a message.
- Parameters
- ----------
- socket : ZMQStream or Socket
- The socket or stream to use in receiving.
- Returns
- -------
- [idents], msg
- [idents] is a list of idents and msg is a nested message dict of
- same format as self.msg returns.
- """
- if isinstance(socket, ZMQStream): # type:ignore[unreachable]
- socket = socket.socket # type:ignore[unreachable]
- if isinstance(socket, zmq.asyncio.Socket):
- socket = zmq.Socket.shadow(socket.underlying)
- try:
- msg_list = socket.recv_multipart(mode, copy=copy)
- except zmq.ZMQError as e:
- if e.errno == zmq.EAGAIN:
- # We can convert EAGAIN to None as we know in this case
- # recv_multipart won't return None.
- return None, None
- else:
- raise
- # split multipart message into identity list and message dict
- # invalid large messages can cause very expensive string comparisons
- idents, msg_list = self.feed_identities(msg_list, copy)
- try:
- return idents, self.deserialize(msg_list, content=content, copy=copy)
- except Exception as e:
- # TODO: handle it
- raise e
- def feed_identities(
- self, msg_list: list[bytes] | list[zmq.Message], copy: bool = True
- ) -> tuple[list[bytes], list[bytes] | list[zmq.Message]]:
- """Split the identities from the rest of the message.
- Feed until DELIM is reached, then return the prefix as idents and
- remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
- but that would be silly.
- Parameters
- ----------
- msg_list : a list of Message or bytes objects
- The message to be split.
- copy : bool
- flag determining whether the arguments are bytes or Messages
- Returns
- -------
- (idents, msg_list) : two lists
- idents will always be a list of bytes, each of which is a ZMQ
- identity. msg_list will be a list of bytes or zmq.Messages of the
- form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
- should be unpackable/unserializable via self.deserialize at this
- point.
- """
- if copy:
- msg_list = t.cast(t.List[bytes], msg_list)
- idx = msg_list.index(DELIM)
- return msg_list[:idx], msg_list[idx + 1 :]
- else:
- msg_list = t.cast(t.List[zmq.Message], msg_list)
- failed = True
- for idx, m in enumerate(msg_list): # noqa
- if m.bytes == DELIM:
- failed = False
- break
- if failed:
- msg = "DELIM not in msg_list"
- raise ValueError(msg)
- idents, msg_list = msg_list[:idx], msg_list[idx + 1 :]
- return [bytes(m.bytes) for m in idents], msg_list
- def _add_digest(self, signature: bytes) -> None:
- """add a digest to history to protect against replay attacks"""
- if self.digest_history_size == 0:
- # no history, never add digests
- return
- self.digest_history.add(signature)
- if len(self.digest_history) > self.digest_history_size:
- # threshold reached, cull 10%
- self._cull_digest_history()
- def _cull_digest_history(self) -> None:
- """cull the digest history
- Removes a randomly selected 10% of the digest history
- """
- current = len(self.digest_history)
- n_to_cull = max(int(current // 10), current - self.digest_history_size)
- if n_to_cull >= current:
- self.digest_history = set()
- return
- to_cull = random.sample(tuple(sorted(self.digest_history)), n_to_cull)
- self.digest_history.difference_update(to_cull)
- def deserialize(
- self,
- msg_list: list[bytes] | list[zmq.Message],
- content: bool = True,
- copy: bool = True,
- ) -> dict[str, t.Any]:
- """Unserialize a msg_list to a nested message dict.
- This is roughly the inverse of serialize. The serialize/deserialize
- methods work with full message lists, whereas pack/unpack work with
- the individual message parts in the message list.
- Parameters
- ----------
- msg_list : list of bytes or Message objects
- The list of message parts of the form [HMAC,p_header,p_parent,
- p_metadata,p_content,buffer1,buffer2,...].
- content : bool (True)
- Whether to unpack the content dict (True), or leave it packed
- (False).
- copy : bool (True)
- Whether msg_list contains bytes (True) or the non-copying Message
- objects in each place (False).
- Returns
- -------
- msg : dict
- The nested message dict with top-level keys [header, parent_header,
- content, buffers]. The buffers are returned as memoryviews.
- """
- minlen = 5
- message = {}
- if not copy:
- # pyzmq didn't copy the first parts of the message, so we'll do it
- msg_list = t.cast(t.List[zmq.Message], msg_list)
- msg_list_beginning = [bytes(msg.bytes) for msg in msg_list[:minlen]]
- msg_list = t.cast(t.List[bytes], msg_list)
- msg_list = msg_list_beginning + msg_list[minlen:]
- msg_list = t.cast(t.List[bytes], msg_list)
- if self.auth is not None:
- signature = msg_list[0]
- if not signature:
- msg = "Unsigned Message"
- raise ValueError(msg)
- if signature in self.digest_history:
- raise ValueError("Duplicate Signature: %r" % signature)
- if content:
- # Only store signature if we are unpacking content, don't store if just peeking.
- self._add_digest(signature)
- check = self.sign(msg_list[1:5])
- if not compare_digest(signature, check):
- msg = "Invalid Signature: %r" % signature
- raise ValueError(msg)
- if not len(msg_list) >= minlen:
- msg = "malformed message, must have at least %i elements" % minlen
- raise TypeError(msg)
- header = self.unpack(msg_list[1])
- message["header"] = extract_dates(header)
- message["msg_id"] = header["msg_id"]
- message["msg_type"] = header["msg_type"]
- message["parent_header"] = extract_dates(self.unpack(msg_list[2]))
- message["metadata"] = self.unpack(msg_list[3])
- if content:
- message["content"] = self.unpack(msg_list[4])
- else:
- message["content"] = msg_list[4]
- buffers = [memoryview(b) for b in msg_list[5:]]
- if buffers and buffers[0].shape is None:
- # force copy to workaround pyzmq #646
- msg_list = t.cast(t.List[zmq.Message], msg_list)
- buffers = [memoryview(bytes(b.bytes)) for b in msg_list[5:]]
- message["buffers"] = buffers
- if self.debug:
- pprint.pprint(message) # noqa
- # adapt to the current version
- return adapt(message)
- def unserialize(self, *args: t.Any, **kwargs: t.Any) -> dict[str, t.Any]:
- """**DEPRECATED** Use deserialize instead."""
- # pragma: no cover
- warnings.warn(
- "Session.unserialize is deprecated. Use Session.deserialize.",
- DeprecationWarning,
- stacklevel=2,
- )
- return self.deserialize(*args, **kwargs)
|