session.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111
  1. """Session object for building, serializing, sending, and receiving messages.
  2. The Session object supports serialization, HMAC signatures,
  3. and metadata on messages.
  4. Also defined here are utilities for working with Sessions:
  5. * A SessionFactory to be used as a base class for configurables that work with
  6. Sessions.
  7. * A Message object for convenience that allows attribute-access to the msg dict.
  8. """
  9. # Copyright (c) Jupyter Development Team.
  10. # Distributed under the terms of the Modified BSD License.
  11. from __future__ import annotations
  12. import functools
  13. import hashlib
  14. import hmac
  15. import json
  16. import logging
  17. import os
  18. import pickle
  19. import pprint
  20. import random
  21. import typing as t
  22. import warnings
  23. from binascii import b2a_hex
  24. from datetime import datetime, timezone
  25. from hmac import compare_digest
  26. # We are using compare_digest to limit the surface of timing attacks
  27. import zmq.asyncio
  28. from tornado.ioloop import IOLoop
  29. from traitlets import (
  30. Any,
  31. Bool,
  32. Callable,
  33. CBytes,
  34. CUnicode,
  35. Dict,
  36. DottedObjectName,
  37. Instance,
  38. Integer,
  39. Set,
  40. TraitError,
  41. Unicode,
  42. observe,
  43. )
  44. from traitlets.config.configurable import Configurable, LoggingConfigurable
  45. from traitlets.log import get_logger
  46. from traitlets.utils.importstring import import_item
  47. from zmq.eventloop.zmqstream import ZMQStream
  48. from ._version import protocol_version
  49. from .adapter import adapt
  50. from .jsonutil import extract_dates, json_clean, json_default, squash_dates
  51. PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
  52. utc = timezone.utc
  53. # -----------------------------------------------------------------------------
  54. # utility functions
  55. # -----------------------------------------------------------------------------
  56. def squash_unicode(obj: t.Any) -> t.Any:
  57. """coerce unicode back to bytestrings."""
  58. if isinstance(obj, dict):
  59. for key in list(obj.keys()):
  60. obj[key] = squash_unicode(obj[key])
  61. if isinstance(key, str):
  62. obj[squash_unicode(key)] = obj.pop(key)
  63. elif isinstance(obj, list):
  64. for i, v in enumerate(obj):
  65. obj[i] = squash_unicode(v)
  66. elif isinstance(obj, str):
  67. obj = obj.encode("utf8")
  68. return obj
  69. # -----------------------------------------------------------------------------
  70. # globals and defaults
  71. # -----------------------------------------------------------------------------
  72. # default values for the thresholds:
  73. MAX_ITEMS = 64
  74. MAX_BYTES = 1024
  75. # ISO8601-ify datetime objects
  76. # allow unicode
  77. # disallow nan, because it's not actually valid JSON
  78. def json_packer(obj: t.Any) -> bytes:
  79. """Convert a json object to a bytes."""
  80. try:
  81. return json.dumps(
  82. obj,
  83. default=json_default,
  84. ensure_ascii=False,
  85. allow_nan=False,
  86. ).encode("utf8", errors="surrogateescape")
  87. except (TypeError, ValueError) as e:
  88. # Fallback to trying to clean the json before serializing
  89. packed = json.dumps(
  90. json_clean(obj),
  91. default=json_default,
  92. ensure_ascii=False,
  93. allow_nan=False,
  94. ).encode("utf8", errors="surrogateescape")
  95. warnings.warn(
  96. f"Message serialization failed with:\n{e}\n"
  97. "Supporting this message is deprecated in jupyter-client 7, please make "
  98. "sure your message is JSON-compliant",
  99. stacklevel=2,
  100. )
  101. return packed
  102. def json_unpacker(s: str | bytes) -> t.Any:
  103. """Convert a json bytes or string to an object."""
  104. if isinstance(s, bytes):
  105. s = s.decode("utf8", "replace")
  106. return json.loads(s)
  107. try:
  108. import orjson
  109. except ModuleNotFoundError:
  110. has_orjson = False
  111. orjson_packer, orjson_unpacker = json_packer, json_unpacker
  112. else:
  113. has_orjson = True
  114. def orjson_packer(
  115. obj: t.Any, *, option: int | None = orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z
  116. ) -> bytes:
  117. """Convert a json object to a bytes using orjson with fallback to json_packer."""
  118. try:
  119. return orjson.dumps(obj, default=json_default, option=option)
  120. except Exception:
  121. return json_packer(obj)
  122. def orjson_unpacker(s: str | bytes) -> t.Any:
  123. """Convert a json bytes or string to an object using orjson with fallback to json_unpacker."""
  124. try:
  125. return orjson.loads(s)
  126. except Exception:
  127. return json_unpacker(s)
# Optional msgpack support.
try:
    import msgpack
except ModuleNotFoundError:
    # msgpack_packer / msgpack_unpacker are deliberately left undefined here;
    # code must check ``has_msgpack`` before touching them.
    has_msgpack = False
else:
    has_msgpack = True
    # Values msgpack cannot pack natively are handed to json_default.
    msgpack_packer = functools.partial(msgpack.packb, default=json_default)
    msgpack_unpacker = msgpack.unpackb
  136. def pickle_packer(o: t.Any) -> bytes:
  137. """Pack an object using the pickle module."""
  138. return pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
# NOTE(review): pickle.loads can execute arbitrary code while unpickling;
# only select the pickle unpacker when every peer is trusted.
pickle_unpacker = pickle.loads

# Frame separating the routing identities from the signed message parts.
DELIM = b"<IDS|MSG>"
# singleton dummy tracker, which will always report as done
DONE = zmq.MessageTracker()
  143. # -----------------------------------------------------------------------------
  144. # Mixin tools for apps that use Sessions
  145. # -----------------------------------------------------------------------------
  146. def new_id() -> str:
  147. """Generate a new random id.
  148. Avoids problematic runtime import in stdlib uuid on Python 2.
  149. Returns
  150. -------
  151. id string (16 random bytes as hex-encoded text, chunks separated by '-')
  152. """
  153. buf = os.urandom(16)
  154. return "-".join(b2a_hex(x).decode("ascii") for x in (buf[:4], buf[4:]))
  155. def new_id_bytes() -> bytes:
  156. """Return new_id as ascii bytes"""
  157. return new_id().encode("ascii")
# Short command-line alias -> fully qualified Session trait.
session_aliases = {
    "ident": "Session.session",
    "user": "Session.username",
    "keyfile": "Session.keyfile",
}

# Flag name -> (config overrides, help text).
# NOTE: new_id_bytes() runs once, when this module is imported, so the key
# attached to "secure" is fixed for the lifetime of the process.
session_flags = {
    "secure": (
        {"Session": {"key": new_id_bytes(), "keyfile": ""}},
        """Use HMAC digests for authentication of messages.
        Setting this flag will generate a new UUID to use as the HMAC key.
        """,
    ),
    "no-secure": (
        {"Session": {"key": b"", "keyfile": ""}},
        """Don't authenticate messages.""",
    ),
}
  175. def default_secure(cfg: t.Any) -> None: # pragma: no cover
  176. """Set the default behavior for a config environment to be secure.
  177. If Session.key/keyfile have not been set, set Session.key to
  178. a new random UUID.
  179. """
  180. warnings.warn("default_secure is deprecated", DeprecationWarning, stacklevel=2)
  181. if "Session" in cfg and ("key" in cfg.Session or "keyfile" in cfg.Session):
  182. return
  183. # key/keyfile not specified, generate new UUID:
  184. cfg.Session.key = new_id_bytes()
  185. def utcnow() -> datetime:
  186. """Return timezone-aware UTC timestamp"""
  187. return datetime.now(utc)
  188. # -----------------------------------------------------------------------------
  189. # Classes
  190. # -----------------------------------------------------------------------------
  191. class SessionFactory(LoggingConfigurable):
  192. """The Base class for configurables that have a Session, Context, logger,
  193. and IOLoop.
  194. """
  195. logname = Unicode("")
  196. @observe("logname")
  197. def _logname_changed(self, change: t.Any) -> None:
  198. self.log = logging.getLogger(change["new"])
  199. # not configurable:
  200. context = Instance("zmq.Context")
  201. def _context_default(self) -> zmq.Context:
  202. return zmq.Context()
  203. session = Instance("jupyter_client.session.Session", allow_none=True)
  204. loop = Instance("tornado.ioloop.IOLoop")
  205. def _loop_default(self) -> IOLoop:
  206. return IOLoop.current()
  207. def __init__(self, **kwargs: t.Any) -> None:
  208. """Initialize a session factory."""
  209. super().__init__(**kwargs)
  210. if self.session is None:
  211. # construct the session
  212. self.session = Session(**kwargs)
  213. class Message:
  214. """A simple message object that maps dict keys to attributes.
  215. A Message can be created from a dict and a dict from a Message instance
  216. simply by calling dict(msg_obj)."""
  217. def __init__(self, msg_dict: dict[str, t.Any]) -> None:
  218. """Initialize a message."""
  219. dct = self.__dict__
  220. for k, v in dict(msg_dict).items():
  221. if isinstance(v, dict):
  222. v = Message(v) # noqa
  223. dct[k] = v
  224. # Having this iterator lets dict(msg_obj) work out of the box.
  225. def __iter__(self) -> t.ItemsView[str, t.Any]:
  226. return iter(self.__dict__.items()) # type:ignore[return-value]
  227. def __repr__(self) -> str:
  228. return repr(self.__dict__)
  229. def __str__(self) -> str:
  230. return pprint.pformat(self.__dict__)
  231. def __contains__(self, k: object) -> bool:
  232. return k in self.__dict__
  233. def __getitem__(self, k: str) -> t.Any:
  234. return self.__dict__[k]
  235. def msg_header(
  236. msg_id: str, msg_type: str, username: str, session: Session | str
  237. ) -> dict[str, t.Any]:
  238. """Create a new message header"""
  239. date = utcnow()
  240. version = protocol_version
  241. return locals()
  242. def extract_header(msg_or_header: dict[str, t.Any]) -> dict[str, t.Any]:
  243. """Given a message or header, return the header."""
  244. if not msg_or_header:
  245. return {}
  246. try:
  247. # See if msg_or_header is the entire message.
  248. h = msg_or_header["header"]
  249. except KeyError:
  250. try:
  251. # See if msg_or_header is just the header
  252. h = msg_or_header["msg_id"]
  253. except KeyError:
  254. raise
  255. else:
  256. h = msg_or_header
  257. if not isinstance(h, dict):
  258. h = dict(h)
  259. return h
  260. class Session(Configurable):
  261. """Object for handling serialization and sending of messages.
  262. The Session object handles building messages and sending them
  263. with ZMQ sockets or ZMQStream objects. Objects can communicate with each
  264. other over the network via Session objects, and only need to work with the
  265. dict-based IPython message spec. The Session will handle
  266. serialization/deserialization, security, and metadata.
  267. Sessions support configurable serialization via packer/unpacker traits,
  268. and signing with HMAC digests via the key/keyfile traits.
  269. Parameters
  270. ----------
  271. debug : bool
  272. whether to trigger extra debugging statements
  273. packer/unpacker : str : 'orjson', 'json', 'pickle', 'msgpack' or import_string
  274. importstrings for methods to serialize message parts. If just
  275. 'json' or 'pickle', predefined JSON and pickle packers will be used.
  276. Otherwise, the entire importstring must be used.
  277. The functions must accept at least valid JSON input, and output *bytes*.
  278. For example, to use msgpack:
  279. packer = 'msgpack.packb', unpacker='msgpack.unpackb'
  280. pack/unpack : callables
  281. You can also set the pack/unpack callables for serialization directly.
  282. session : bytes
  283. the ID of this Session object. The default is to generate a new UUID.
  284. username : unicode
  285. username added to message headers. The default is to ask the OS.
  286. key : bytes
  287. The key used to initialize an HMAC signature. If unset, messages
  288. will not be signed or checked.
  289. keyfile : filepath
  290. The file containing a key. If this is set, `key` will be initialized
  291. to the contents of the file.
  292. """
    # When true, send() pretty-prints each message it handles.
    debug = Bool(False, config=True, help="""Debug output in the Session""")

    # When true, send() refuses to send from a process whose PID differs
    # from the one recorded in __init__.
    check_pid = Bool(
        True,
        config=True,
        help="""Whether to check PID to protect against calls after fork.
        This check can be disabled if fork-safety is handled elsewhere.
        """,
    )

    # serialization traits:
    # `packer` / `unpacker` are names; the observer below resolves them
    # into the `pack` / `unpack` callables.  Besides the values listed in
    # the help text, the observer also recognizes 'orjson' and 'msgpack'.
    packer = DottedObjectName(
        "orjson" if has_orjson else "json",
        config=True,
        help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""",
    )
    unpacker = DottedObjectName(
        "orjson" if has_orjson else "json",
        config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""",
    )
    pack = Callable(orjson_packer if has_orjson else json_packer)  # the actual packer function
    unpack = Callable(
        orjson_unpacker if has_orjson else json_unpacker
    )  # the actual unpacker function
  319. @observe("packer", "unpacker")
  320. def _packer_unpacker_changed(self, change: t.Any) -> None:
  321. new = change["new"].lower()
  322. if new == "orjson" and has_orjson:
  323. self.pack, self.unpack = orjson_packer, orjson_unpacker
  324. elif new == "json" or new == "orjson":
  325. self.pack, self.unpack = json_packer, json_unpacker
  326. elif new == "pickle":
  327. self.pack, self.unpack = pickle_packer, pickle_unpacker
  328. elif new == "msgpack" and has_msgpack:
  329. self.pack, self.unpack = msgpack_packer, msgpack_unpacker
  330. else:
  331. obj = import_item(str(change["new"]))
  332. name = "pack" if change["name"] == "packer" else "unpack"
  333. self.set_trait(name, obj)
  334. return
  335. self.packer = self.unpacker = change["new"]
    session = CUnicode("", config=True, help="""The UUID identifying this session.""")

    def _session_default(self) -> str:
        # Generate the id lazily and keep the bytes mirror in step with it.
        u = new_id()
        self.bsession = u.encode("ascii")
        return u

    @observe("session")
    def _session_changed(self, change: t.Any) -> None:
        # Keep bsession equal to the ASCII encoding of session; a non-ASCII
        # session id makes this encode() raise.
        self.bsession = self.session.encode("ascii")

    # bsession is the session as bytes
    bsession = CBytes(b"")

    # Falls back to the literal "username" when $USER is not set.
    username = Unicode(
        os.environ.get("USER", "username"),
        help="""Username for the Session. Default is your system username.""",
        config=True,
    )

    metadata = Dict(
        {},
        config=True,
        help="Metadata dictionary, which serves as the default top-level metadata dict for each message.",
    )

    # if 0, no adapting to do.
    adapt_version = Integer(0)
  358. # message signature related traits:
  359. key = CBytes(config=True, help="""execution key, for signing messages.""")
  360. def _key_default(self) -> bytes:
  361. return new_id_bytes()
  362. @observe("key")
  363. def _key_changed(self, change: t.Any) -> None:
  364. self._new_auth()
  365. signature_scheme = Unicode(
  366. "hmac-sha256",
  367. config=True,
  368. help="""The digest scheme used to construct the message signatures.
  369. Must have the form 'hmac-HASH'.""",
  370. )
  371. @observe("signature_scheme")
  372. def _signature_scheme_changed(self, change: t.Any) -> None:
  373. new = change["new"]
  374. if not new.startswith("hmac-"):
  375. raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
  376. hash_name = new.split("-", 1)[1]
  377. try:
  378. self.digest_mod = getattr(hashlib, hash_name)
  379. except AttributeError as e:
  380. raise TraitError("hashlib has no such attribute: %s" % hash_name) from e
  381. self._new_auth()
  382. digest_mod = Any()
  383. def _digest_mod_default(self) -> t.Callable:
  384. return hashlib.sha256
  385. auth = Instance(hmac.HMAC, allow_none=True)
  386. def _new_auth(self) -> None:
  387. if self.key:
  388. self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
  389. else:
  390. self.auth = None
    # Signatures remembered so far; clone() gives each copy its own set.
    digest_history = Set()
    digest_history_size = Integer(
        2**16,
        config=True,
        help="""The maximum number of digests to remember.
        The digest history will be culled when it exceeds this value.
        """,
    )

    keyfile = Unicode("", config=True, help="""path to file containing execution key.""")

    @observe("keyfile")
    def _keyfile_changed(self, change: t.Any) -> None:
        # Load the key from the new path, stripping surrounding whitespace.
        # NOTE(review): open() is attempted for any new value, including "",
        # so clearing a previously set keyfile raises -- confirm intended.
        with open(change["new"], "rb") as f:
            self.key = f.read().strip()

    # for protecting against sends from forks
    pid = Integer()

    # thresholds:
    # send() compares the largest outgoing frame against copy_threshold.
    copy_threshold = Integer(
        2**16,
        config=True,
        help="Threshold (in bytes) beyond which a buffer should be sent without copying.",
    )
    buffer_threshold = Integer(
        MAX_BYTES,
        config=True,
        help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.",
    )
    item_threshold = Integer(
        MAX_ITEMS,
        config=True,
        help="""The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """,
    )
    def __init__(self, **kwargs: t.Any) -> None:
        """Create a Session object.

        Parameters
        ----------
        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'orjson', 'json', 'pickle', 'msgpack' or import_string
            importstrings for methods to serialize message parts.  If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.
            The functions must accept at least valid JSON input, and output
            *bytes*.
            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object.  The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If unset, messages
            will not be signed or checked.
        signature_scheme : str
            The message digest scheme. Currently must be of the form 'hmac-HASH',
            where 'HASH' is a hashing function available in Python's hashlib.
            The default is 'hmac-sha256'.
            This is ignored if 'key' is empty.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be
            initialized to the contents of the file.
        """
        super().__init__(**kwargs)
        # Fail fast if the configured pack/unpack pair is unusable.
        self._check_packers()
        # Pre-packed empty dict; serialize() uses it when content is None.
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session  # noqa
        # Remember the creating process so send() can detect use from a fork.
        self.pid = os.getpid()
        self._new_auth()
        if not self.key:
            get_logger().warning(
                "Message signing is disabled. This is insecure and not recommended!"
            )
  471. def clone(self) -> Session:
  472. """Create a copy of this Session
  473. Useful when connecting multiple times to a given kernel.
  474. This prevents a shared digest_history warning about duplicate digests
  475. due to multiple connections to IOPub in the same process.
  476. .. versionadded:: 5.1
  477. """
  478. # make a copy
  479. new_session = type(self)()
  480. for name in self.traits():
  481. setattr(new_session, name, getattr(self, name))
  482. # fork digest_history
  483. new_session.digest_history = set()
  484. new_session.digest_history.update(self.digest_history)
  485. return new_session
  486. message_count = 0
  487. @property
  488. def msg_id(self) -> str:
  489. message_number = self.message_count
  490. self.message_count += 1
  491. return f"{self.session}_{os.getpid()}_{message_number}"
  492. def _check_packers(self) -> None:
  493. """check packers for datetime support."""
  494. pack = self.pack
  495. unpack = self.unpack
  496. # check simple serialization
  497. msg_list = {"a": [1, "hi"]}
  498. try:
  499. packed = pack(msg_list)
  500. except Exception as e:
  501. msg = f"packer '{self.packer}' could not serialize a simple message: {e}"
  502. raise ValueError(msg) from e
  503. # ensure packed message is bytes
  504. if not isinstance(packed, bytes):
  505. raise ValueError("message packed to %r, but bytes are required" % type(packed))
  506. # check that unpack is pack's inverse
  507. try:
  508. unpacked = unpack(packed)
  509. assert unpacked == msg_list
  510. except Exception as e:
  511. msg = f"unpacker {self.unpacker!r} could not handle output from packer {self.packer!r}: {e}"
  512. raise ValueError(msg) from e
  513. # check datetime support
  514. msg_datetime = {"t": utcnow()}
  515. try:
  516. unpacked = unpack(pack(msg_datetime))
  517. if isinstance(unpacked["t"], datetime):
  518. msg = "Shouldn't deserialize to datetime"
  519. raise ValueError(msg)
  520. except Exception:
  521. self.pack = lambda o: pack(squash_dates(o))
  522. self.unpack = lambda s: unpack(s)
  523. def msg_header(self, msg_type: str) -> dict[str, t.Any]:
  524. """Create a header for a message type."""
  525. return msg_header(self.msg_id, msg_type, self.username, self.session)
  526. def msg(
  527. self,
  528. msg_type: str,
  529. content: dict | None = None,
  530. parent: dict[str, t.Any] | None = None,
  531. header: dict[str, t.Any] | None = None,
  532. metadata: dict[str, t.Any] | None = None,
  533. ) -> dict[str, t.Any]:
  534. """Return the nested message dict.
  535. This format is different from what is sent over the wire. The
  536. serialize/deserialize methods converts this nested message dict to the wire
  537. format, which is a list of message parts.
  538. """
  539. msg = {}
  540. header = self.msg_header(msg_type) if header is None else header
  541. msg["header"] = header
  542. msg["msg_id"] = header["msg_id"]
  543. msg["msg_type"] = header["msg_type"]
  544. msg["parent_header"] = {} if parent is None else extract_header(parent)
  545. msg["content"] = {} if content is None else content
  546. msg["metadata"] = self.metadata.copy()
  547. if metadata is not None:
  548. msg["metadata"].update(metadata)
  549. return msg
  550. def sign(self, msg_list: list) -> bytes:
  551. """Sign a message with HMAC digest. If no auth, return b''.
  552. Parameters
  553. ----------
  554. msg_list : list
  555. The [p_header,p_parent,p_content] part of the message list.
  556. """
  557. if self.auth is None:
  558. return b""
  559. h = self.auth.copy()
  560. for m in msg_list:
  561. h.update(m)
  562. return h.hexdigest().encode()
  563. def serialize(
  564. self,
  565. msg: dict[str, t.Any],
  566. ident: list[bytes] | bytes | None = None,
  567. ) -> list[bytes]:
  568. """Serialize the message components to bytes.
  569. This is roughly the inverse of deserialize. The serialize/deserialize
  570. methods work with full message lists, whereas pack/unpack work with
  571. the individual message parts in the message list.
  572. Parameters
  573. ----------
  574. msg : dict or Message
  575. The next message dict as returned by the self.msg method.
  576. Returns
  577. -------
  578. msg_list : list
  579. The list of bytes objects to be sent with the format::
  580. [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
  581. p_metadata, p_content, buffer1, buffer2, ...]
  582. In this list, the ``p_*`` entities are the packed or serialized
  583. versions, so if JSON is used, these are utf8 encoded JSON strings.
  584. """
  585. content = msg.get("content", {})
  586. if content is None:
  587. content = self.none
  588. elif isinstance(content, dict):
  589. content = self.pack(content)
  590. elif isinstance(content, bytes):
  591. # content is already packed, as in a relayed message
  592. pass
  593. elif isinstance(content, str):
  594. # should be bytes, but JSON often spits out unicode
  595. content = content.encode("utf8")
  596. else:
  597. raise TypeError("Content incorrect type: %s" % type(content))
  598. real_message = [
  599. self.pack(msg["header"]),
  600. self.pack(msg["parent_header"]),
  601. self.pack(msg["metadata"]),
  602. content,
  603. ]
  604. to_send = []
  605. if isinstance(ident, list):
  606. # accept list of idents
  607. to_send.extend(ident)
  608. elif ident is not None:
  609. to_send.append(ident)
  610. to_send.append(DELIM)
  611. signature = self.sign(real_message)
  612. to_send.append(signature)
  613. to_send.extend(real_message)
  614. return to_send
  615. def send(
  616. self,
  617. stream: zmq.sugar.socket.Socket | ZMQStream | None,
  618. msg_or_type: dict[str, t.Any] | str,
  619. content: dict[str, t.Any] | None = None,
  620. parent: dict[str, t.Any] | None = None,
  621. ident: bytes | list[bytes] | None = None,
  622. buffers: list[bytes | memoryview[bytes]] | None = None,
  623. track: bool = False,
  624. header: dict[str, t.Any] | None = None,
  625. metadata: dict[str, t.Any] | None = None,
  626. ) -> dict[str, t.Any] | None:
  627. """Build and send a message via stream or socket.
  628. The message format used by this function internally is as follows:
  629. [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
  630. buffer1,buffer2,...]
  631. The serialize/deserialize methods convert the nested message dict into this
  632. format.
  633. Parameters
  634. ----------
  635. stream : zmq.Socket or ZMQStream
  636. The socket-like object used to send the data.
  637. msg_or_type : str or Message/dict
  638. Normally, msg_or_type will be a msg_type unless a message is being
  639. sent more than once. If a header is supplied, this can be set to
  640. None and the msg_type will be pulled from the header.
  641. content : dict or None
  642. The content of the message (ignored if msg_or_type is a message).
  643. header : dict or None
  644. The header dict for the message (ignored if msg_to_type is a message).
  645. parent : Message or dict or None
  646. The parent or parent header describing the parent of this message
  647. (ignored if msg_or_type is a message).
  648. ident : bytes or list of bytes
  649. The zmq.IDENTITY routing path.
  650. metadata : dict or None
  651. The metadata describing the message
  652. buffers : list or None
  653. The already-serialized buffers to be appended to the message.
  654. track : bool
  655. Whether to track. Only for use with Sockets, because ZMQStream
  656. objects cannot track messages.
  657. Returns
  658. -------
  659. msg : dict
  660. The constructed message.
  661. """
  662. if not isinstance(stream, zmq.Socket):
  663. # ZMQStreams and dummy sockets do not support tracking.
  664. track = False
  665. if isinstance(stream, zmq.asyncio.Socket):
  666. assert stream is not None
  667. stream = zmq.Socket.shadow(stream.underlying)
  668. if isinstance(msg_or_type, Message | dict):
  669. # We got a Message or message dict, not a msg_type so don't
  670. # build a new Message.
  671. msg = msg_or_type
  672. buffers = buffers or msg.get("buffers", [])
  673. else:
  674. msg = self.msg(
  675. msg_or_type,
  676. content=content,
  677. parent=parent,
  678. header=header,
  679. metadata=metadata,
  680. )
  681. if self.check_pid and os.getpid() != self.pid:
  682. get_logger().warning("WARNING: attempted to send message from fork\n%s", msg)
  683. return None
  684. buffers = [] if buffers is None else buffers
  685. for idx, buf in enumerate(buffers):
  686. if isinstance(buf, memoryview):
  687. view = buf
  688. else:
  689. try:
  690. # check to see if buf supports the buffer protocol.
  691. view = memoryview(buf)
  692. except TypeError as e:
  693. emsg = "Buffer objects must support the buffer protocol."
  694. raise TypeError(emsg) from e
  695. if not view.contiguous:
  696. # zmq requires memoryviews to be contiguous
  697. raise ValueError("Buffer %i (%r) is not contiguous" % (idx, buf))
  698. if self.adapt_version:
  699. msg = adapt(msg, self.adapt_version)
  700. to_send = self.serialize(msg, ident)
  701. to_send.extend(buffers) # type: ignore[arg-type]
  702. longest = max([len(s) for s in to_send])
  703. copy = longest < self.copy_threshold
  704. if stream and buffers and track and not copy:
  705. # only really track when we are doing zero-copy buffers
  706. tracker = stream.send_multipart(to_send, copy=False, track=True)
  707. elif stream:
  708. # use dummy tracker, which will be done immediately
  709. tracker = DONE
  710. stream.send_multipart(to_send, copy=copy)
  711. else:
  712. tracker = DONE
  713. if self.debug:
  714. pprint.pprint(msg) # noqa
  715. pprint.pprint(to_send) # noqa
  716. pprint.pprint(buffers) # noqa
  717. msg["tracker"] = tracker
  718. return msg
  def send_raw(
      self,
      stream: zmq.sugar.socket.Socket,
      msg_list: list,
      flags: int = 0,
      copy: bool = True,
      ident: bytes | list[bytes] | None = None,
  ) -> None:
      """Send an already-serialized message, prefixed by its routing idents.

      The outgoing multipart message is laid out as
      ``[*idents, DELIM, signature, *msg_list]``.

      Parameters
      ----------
      stream : ZMQStream or Socket
          The ZMQ stream or socket to use for sending the message.
      msg_list : list
          The serialized list of message parts. This only includes the
          [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...]
          portion of the message.
      flags : int
          Forwarded positionally to ``stream.send_multipart``.
      copy : bool
          Forwarded to ``stream.send_multipart`` as ``copy``.
      ident : ident or list
          A single ident or a list of idents to use in sending.
      """
      if ident is None:
          routing = []
      elif isinstance(ident, bytes):
          # a lone ident is treated as a one-element routing prefix
          routing = [ident]
      else:
          routing = ident

      # Only the first four parts are signed: buffers are not included in
      # the signature (per spec).
      to_send = [*routing, DELIM, self.sign(msg_list[0:4]), *msg_list]

      if isinstance(stream, zmq.asyncio.Socket):
          # send through a synchronous shadow of the asyncio socket
          stream = zmq.Socket.shadow(stream.underlying)
      stream.send_multipart(to_send, flags, copy=copy)
  def recv(
      self,
      socket: zmq.sugar.socket.Socket,
      mode: int = zmq.NOBLOCK,
      content: bool = True,
      copy: bool = True,
  ) -> tuple[list[bytes] | None, dict[str, t.Any] | None]:
      """Receive one multipart message and unpack it.

      Parameters
      ----------
      socket : ZMQStream or Socket
          The socket or stream to use in receiving.
      mode : int
          Flags forwarded to ``recv_multipart``.
      content : bool
          Forwarded to ``self.deserialize``.
      copy : bool
          Forwarded to ``recv_multipart``, ``self.feed_identities`` and
          ``self.deserialize``.

      Returns
      -------
      [idents], msg
          [idents] is a list of idents and msg is a nested message dict of
          same format as self.msg returns. ``(None, None)`` is returned when
          the receive fails with ``EAGAIN``.
      """
      if isinstance(socket, ZMQStream):  # type:ignore[unreachable]
          socket = socket.socket  # type:ignore[unreachable]
      if isinstance(socket, zmq.asyncio.Socket):
          socket = zmq.Socket.shadow(socket.underlying)

      try:
          frames = socket.recv_multipart(mode, copy=copy)
      except zmq.ZMQError as err:
          if err.errno != zmq.EAGAIN:
              raise
          # We can convert EAGAIN to None as we know in this case
          # recv_multipart won't return None.
          return None, None

      # split multipart message into identity list and message dict
      # invalid large messages can cause very expensive string comparisons
      idents, payload = self.feed_identities(frames, copy)
      try:
          message = self.deserialize(payload, content=content, copy=copy)
      except Exception as err:
          # TODO: handle it
          raise err
      return idents, message
  def feed_identities(
      self, msg_list: list[bytes] | list[zmq.Message], copy: bool = True
  ) -> tuple[list[bytes], list[bytes] | list[zmq.Message]]:
      """Split the routing identities off the front of a message.

      Everything before the first DELIM part is returned as the idents and
      everything after it as the remaining msg_list. An ident equal to DELIM
      would break this, but that would be silly.

      Parameters
      ----------
      msg_list : a list of Message or bytes objects
          The message to be split.
      copy : bool
          flag determining whether the arguments are bytes or Messages

      Returns
      -------
      (idents, msg_list) : two lists
          idents will always be a list of bytes, each of which is a ZMQ
          identity. msg_list will be a list of bytes or zmq.Messages of the
          form [HMAC,p_header,p_parent,p_metadata,p_content,buffer1,...] and
          should be unpackable/unserializable via self.deserialize at this
          point.

      Raises
      ------
      ValueError
          If no part of msg_list equals DELIM.
      """
      if not copy:
          frames = t.cast(t.List[zmq.Message], msg_list)
          for pos, frame in enumerate(frames):  # noqa
              if frame.bytes == DELIM:
                  break
          else:
              # loop finished without hitting the delimiter
              msg = "DELIM not in msg_list"
              raise ValueError(msg)
          head, tail = frames[:pos], frames[pos + 1 :]
          return [bytes(part.bytes) for part in head], tail

      parts = t.cast(t.List[bytes], msg_list)
      pos = parts.index(DELIM)
      return parts[:pos], parts[pos + 1 :]
  def _add_digest(self, signature: bytes) -> None:
      """Record a signature in the digest history (replay protection).

      Nothing is stored when ``digest_history_size`` is 0. Once the history
      grows past ``digest_history_size`` it is culled.
      """
      if self.digest_history_size != 0:
          self.digest_history.add(signature)
          over_limit = len(self.digest_history) > self.digest_history_size
          if over_limit:
              # threshold reached, cull 10%
              self._cull_digest_history()
  def _cull_digest_history(self) -> None:
      """Shrink the digest history.

      Drops a random sample of entries: 10% of the current history, or
      however many are over ``digest_history_size`` if that is more. When
      that count covers the whole history, it is replaced by an empty set.
      """
      size = len(self.digest_history)
      overflow = size - self.digest_history_size
      n_to_cull = max(size // 10, overflow)
      if n_to_cull < size:
          # sample from a sorted tuple: random.sample needs a sequence
          victims = random.sample(tuple(sorted(self.digest_history)), n_to_cull)
          self.digest_history.difference_update(victims)
      else:
          self.digest_history = set()
  def deserialize(
      self,
      msg_list: list[bytes] | list[zmq.Message],
      content: bool = True,
      copy: bool = True,
  ) -> dict[str, t.Any]:
      """Unserialize a msg_list to a nested message dict.

      This is roughly the inverse of serialize. The serialize/deserialize
      methods work with full message lists, whereas pack/unpack work with
      the individual message parts in the message list.

      Parameters
      ----------
      msg_list : list of bytes or Message objects
          The list of message parts of the form [HMAC,p_header,p_parent,
          p_metadata,p_content,buffer1,buffer2,...].
      content : bool (True)
          Whether to unpack the content dict (True), or leave it packed
          (False). The signature is only recorded in the digest history
          when content is True.
      copy : bool (True)
          Whether msg_list contains bytes (True) or the non-copying Message
          objects in each place (False).

      Returns
      -------
      msg : dict
          The nested message dict, built with the keys header, msg_id,
          msg_type, parent_header, metadata, content and buffers and then
          passed through ``adapt``. The buffers are memoryviews.

      Raises
      ------
      ValueError
          When ``self.auth`` is set and the message is unsigned, carries a
          signature already present in the digest history, or carries a
          signature that does not match the message parts.
      TypeError
          When msg_list has fewer than five parts.
      """
      minlen = 5
      message = {}
      if not copy:
          # pyzmq didn't copy the first parts of the message, so we'll do it
          msg_list = t.cast(t.List[zmq.Message], msg_list)
          msg_list_beginning = [bytes(msg.bytes) for msg in msg_list[:minlen]]
          msg_list = t.cast(t.List[bytes], msg_list)
          msg_list = msg_list_beginning + msg_list[minlen:]
      msg_list = t.cast(t.List[bytes], msg_list)
      if self.auth is not None:
          signature = msg_list[0]
          if not signature:
              raise ValueError("Unsigned Message")
          if signature in self.digest_history:
              raise ValueError(f"Duplicate Signature: {signature!r}")
          # Verify BEFORE recording: storing unverified signatures would let
          # a peer without the key fill the replay history with forgeries
          # and force legitimate digests to be culled.
          check = self.sign(msg_list[1:5])
          if not compare_digest(signature, check):
              raise ValueError(f"Invalid Signature: {signature!r}")
          if content:
              # Only store signature if we are unpacking content, don't
              # store if just peeking.
              self._add_digest(signature)
      if not len(msg_list) >= minlen:
          raise TypeError(f"malformed message, must have at least {minlen} elements")
      header = self.unpack(msg_list[1])
      message["header"] = extract_dates(header)
      message["msg_id"] = header["msg_id"]
      message["msg_type"] = header["msg_type"]
      message["parent_header"] = extract_dates(self.unpack(msg_list[2]))
      message["metadata"] = self.unpack(msg_list[3])
      if content:
          message["content"] = self.unpack(msg_list[4])
      else:
          message["content"] = msg_list[4]
      buffers = [memoryview(b) for b in msg_list[5:]]
      if buffers and buffers[0].shape is None:
          # force copy to workaround pyzmq #646
          msg_list = t.cast(t.List[zmq.Message], msg_list)
          buffers = [memoryview(bytes(b.bytes)) for b in msg_list[5:]]
      message["buffers"] = buffers
      if self.debug:
          pprint.pprint(message)  # noqa
      # adapt to the current version
      return adapt(message)
  def unserialize(self, *args: t.Any, **kwargs: t.Any) -> dict[str, t.Any]:
      """**DEPRECATED** Use deserialize instead.

      Emits a DeprecationWarning and forwards every argument to
      ``self.deserialize`` unchanged, returning its result.
      """
      # pragma: no cover
      notice = "Session.unserialize is deprecated. Use Session.deserialize."
      # stacklevel=2 attributes the warning to the caller of unserialize
      warnings.warn(notice, DeprecationWarning, stacklevel=2)
      return self.deserialize(*args, **kwargs)