base_comm.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. """Default classes for Comm and CommManager, for usage in IPython."""
  2. # Copyright (c) IPython Development Team.
  3. # Distributed under the terms of the Modified BSD License.
  4. from __future__ import annotations
  5. import contextlib
  6. import logging
  7. import typing as t
  8. import uuid
  9. import comm
  10. if t.TYPE_CHECKING:
  11. from zmq.eventloop.zmqstream import ZMQStream
  12. logger = logging.getLogger("Comm")
  13. MessageType = t.Dict[str, t.Any]
  14. MaybeDict = t.Optional[t.Dict[str, t.Any]]
  15. BuffersType = t.Optional[t.List[bytes]]
  16. CommCallback = t.Callable[[MessageType], None]
  17. CommTargetCallback = t.Callable[["BaseComm", MessageType], None]
  18. class BaseComm:
  19. """Class for communicating between a Frontend and a Kernel
  20. Must be subclassed with a publish_msg method implementation which
  21. sends comm messages through the iopub channel.
  22. """
  23. def __init__(
  24. self,
  25. target_name: str = "comm",
  26. data: MaybeDict = None,
  27. metadata: MaybeDict = None,
  28. buffers: BuffersType = None,
  29. comm_id: str | None = None,
  30. primary: bool = True,
  31. target_module: str | None = None,
  32. topic: bytes | None = None,
  33. _open_data: MaybeDict = None,
  34. _close_data: MaybeDict = None,
  35. **kwargs: t.Any,
  36. ) -> None:
  37. super().__init__(**kwargs)
  38. self.comm_id = comm_id if comm_id else uuid.uuid4().hex
  39. self.primary = primary
  40. self.target_name = target_name
  41. self.target_module = target_module
  42. self.topic = topic if topic else (f"comm-{self.comm_id}").encode("ascii")
  43. self._open_data = _open_data if _open_data else {}
  44. self._close_data = _close_data if _close_data else {}
  45. self._msg_callback: CommCallback | None = None
  46. self._close_callback: CommCallback | None = None
  47. self._closed = True
  48. if self.primary:
  49. # I am primary, open my peer.
  50. self.open(data=data, metadata=metadata, buffers=buffers)
  51. else:
  52. self._closed = False
  53. def publish_msg(
  54. self,
  55. msg_type: str,
  56. data: MaybeDict = None,
  57. metadata: MaybeDict = None,
  58. buffers: BuffersType = None,
  59. **keys: t.Any,
  60. ) -> None:
  61. msg = "publish_msg Comm method is not implemented"
  62. raise NotImplementedError(msg)
  63. def __del__(self) -> None:
  64. """trigger close on gc"""
  65. with contextlib.suppress(Exception):
  66. # any number of things can have gone horribly wrong
  67. # when called during interpreter teardown
  68. self.close(deleting=True)
  69. # publishing messages
  70. def open(
  71. self, data: MaybeDict = None, metadata: MaybeDict = None, buffers: BuffersType = None
  72. ) -> None:
  73. """Open the frontend-side version of this comm"""
  74. if data is None:
  75. data = self._open_data
  76. comm_manager = comm.get_comm_manager()
  77. if comm_manager is None:
  78. msg = "Comms cannot be opened without a comm_manager." # type:ignore[unreachable]
  79. raise RuntimeError(msg)
  80. comm_manager.register_comm(self)
  81. try:
  82. self.publish_msg(
  83. "comm_open",
  84. data=data,
  85. metadata=metadata,
  86. buffers=buffers,
  87. target_name=self.target_name,
  88. target_module=self.target_module,
  89. )
  90. self._closed = False
  91. except Exception:
  92. comm_manager.unregister_comm(self)
  93. raise
  94. def close(
  95. self,
  96. data: MaybeDict = None,
  97. metadata: MaybeDict = None,
  98. buffers: BuffersType = None,
  99. deleting: bool = False,
  100. ) -> None:
  101. """Close the frontend-side version of this comm"""
  102. if self._closed:
  103. # only close once
  104. return
  105. self._closed = True
  106. if data is None:
  107. data = self._close_data
  108. self.publish_msg(
  109. "comm_close",
  110. data=data,
  111. metadata=metadata,
  112. buffers=buffers,
  113. )
  114. if not deleting:
  115. # If deleting, the comm can't be registered
  116. comm.get_comm_manager().unregister_comm(self)
  117. def send(
  118. self, data: MaybeDict = None, metadata: MaybeDict = None, buffers: BuffersType = None
  119. ) -> None:
  120. """Send a message to the frontend-side version of this comm"""
  121. self.publish_msg(
  122. "comm_msg",
  123. data=data,
  124. metadata=metadata,
  125. buffers=buffers,
  126. )
  127. # registering callbacks
  128. def on_close(self, callback: CommCallback | None) -> None:
  129. """Register a callback for comm_close
  130. Will be called with the `data` of the close message.
  131. Call `on_close(None)` to disable an existing callback.
  132. """
  133. self._close_callback = callback
  134. def on_msg(self, callback: CommCallback | None) -> None:
  135. """Register a callback for comm_msg
  136. Will be called with the `data` of any comm_msg messages.
  137. Call `on_msg(None)` to disable an existing callback.
  138. """
  139. self._msg_callback = callback
  140. # handling of incoming messages
  141. def handle_close(self, msg: MessageType) -> None:
  142. """Handle a comm_close message"""
  143. logger.debug("handle_close[%s](%s)", self.comm_id, msg)
  144. if self._close_callback:
  145. self._close_callback(msg)
  146. def handle_msg(self, msg: MessageType) -> None:
  147. """Handle a comm_msg message"""
  148. logger.debug("handle_msg[%s](%s)", self.comm_id, msg)
  149. if self._msg_callback:
  150. from IPython import get_ipython
  151. shell = get_ipython()
  152. if shell:
  153. shell.events.trigger("pre_execute")
  154. self._msg_callback(msg)
  155. if shell:
  156. shell.events.trigger("post_execute")
  157. class CommManager:
  158. """Default CommManager singleton implementation for Comms in the Kernel"""
  159. # Public APIs
  160. def __init__(self) -> None:
  161. self.comms: dict[str, BaseComm] = {}
  162. self.targets: dict[str, CommTargetCallback] = {}
  163. def register_target(self, target_name: str, f: CommTargetCallback | str) -> None:
  164. """Register a callable f for a given target name
  165. f will be called with two arguments when a comm_open message is received with `target`:
  166. - the Comm instance
  167. - the `comm_open` message itself.
  168. f can be a Python callable or an import string for one.
  169. """
  170. if isinstance(f, str):
  171. parts = f.rsplit(".", 1)
  172. if len(parts) == 2:
  173. # called with 'foo.bar....'
  174. package, obj = parts
  175. module = __import__(package, fromlist=[obj])
  176. try:
  177. f = getattr(module, obj)
  178. except AttributeError as e:
  179. error_msg = f"No module named {obj}"
  180. raise ImportError(error_msg) from e
  181. else:
  182. # called with un-dotted string
  183. f = __import__(parts[0])
  184. self.targets[target_name] = t.cast(CommTargetCallback, f)
  185. def unregister_target(self, target_name: str, f: CommTargetCallback) -> CommTargetCallback: # noqa: ARG002
  186. """Unregister a callable registered with register_target"""
  187. return self.targets.pop(target_name)
  188. def register_comm(self, comm: BaseComm) -> str:
  189. """Register a new comm"""
  190. comm_id = comm.comm_id
  191. self.comms[comm_id] = comm
  192. return comm_id
  193. def unregister_comm(self, comm: BaseComm) -> None:
  194. """Unregister a comm, and close its counterpart"""
  195. # unlike get_comm, this should raise a KeyError
  196. comm = self.comms.pop(comm.comm_id)
  197. def get_comm(self, comm_id: str) -> BaseComm | None:
  198. """Get a comm with a particular id
  199. Returns the comm if found, otherwise None.
  200. This will not raise an error,
  201. it will log messages if the comm cannot be found.
  202. """
  203. try:
  204. return self.comms[comm_id]
  205. except KeyError:
  206. logger.warning("No such comm: %s", comm_id)
  207. if logger.isEnabledFor(logging.DEBUG):
  208. # don't create the list of keys if debug messages aren't enabled
  209. logger.debug("Current comms: %s", list(self.comms.keys()))
  210. return None
  211. # Message handlers
  212. def comm_open(self, stream: ZMQStream, ident: str, msg: MessageType) -> None: # noqa: ARG002
  213. """Handler for comm_open messages"""
  214. from comm import create_comm
  215. content = msg["content"]
  216. comm_id = content["comm_id"]
  217. target_name = content["target_name"]
  218. f = self.targets.get(target_name, None)
  219. comm = create_comm(
  220. comm_id=comm_id,
  221. primary=False,
  222. target_name=target_name,
  223. )
  224. self.register_comm(comm)
  225. if f is None:
  226. logger.error("No such comm target registered: %s", target_name)
  227. else:
  228. try:
  229. f(comm, msg)
  230. return
  231. except Exception:
  232. logger.error("Exception opening comm with target: %s", target_name, exc_info=True)
  233. # Failure.
  234. try:
  235. comm.close()
  236. except Exception:
  237. logger.error(
  238. """Could not close comm during `comm_open` failure
  239. clean-up. The comm may not have been opened yet.""",
  240. exc_info=True,
  241. )
  242. def comm_msg(self, stream: ZMQStream, ident: str, msg: MessageType) -> None: # noqa: ARG002
  243. """Handler for comm_msg messages"""
  244. content = msg["content"]
  245. comm_id = content["comm_id"]
  246. comm = self.get_comm(comm_id)
  247. if comm is None:
  248. return
  249. try:
  250. comm.handle_msg(msg)
  251. except Exception:
  252. logger.error("Exception in comm_msg for %s", comm_id, exc_info=True)
  253. def comm_close(self, stream: ZMQStream, ident: str, msg: MessageType) -> None: # noqa: ARG002
  254. """Handler for comm_close messages"""
  255. content = msg["content"]
  256. comm_id = content["comm_id"]
  257. comm = self.get_comm(comm_id)
  258. if comm is None:
  259. return
  260. self.comms[comm_id]._closed = True
  261. del self.comms[comm_id]
  262. try:
  263. comm.handle_close(msg)
  264. except Exception:
  265. logger.error("Exception in comm_close for %s", comm_id, exc_info=True)
  266. __all__ = ["BaseComm", "CommManager"]