socket.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. """zmq Socket class"""
  2. # Copyright (C) PyZMQ Developers
  3. # Distributed under the terms of the Modified BSD License.
  4. import errno as errno_mod
  5. import warnings
  6. import zmq
  7. from zmq.constants import SocketOption, _OptType
  8. from zmq.error import ZMQError, _check_rc, _check_version
  9. from ._cffi import ffi
  10. from ._cffi import lib as C
  11. from .message import Frame
  12. from .utils import _retry_sys_call
  13. nsp = new_sizet_pointer = lambda length: ffi.new('size_t*', length)
  14. def new_uint64_pointer():
  15. return ffi.new('uint64_t*'), nsp(ffi.sizeof('uint64_t'))
  16. def new_int64_pointer():
  17. return ffi.new('int64_t*'), nsp(ffi.sizeof('int64_t'))
  18. def new_int_pointer():
  19. return ffi.new('int*'), nsp(ffi.sizeof('int'))
  20. def new_binary_data(length):
  21. return ffi.new(f'char[{length:d}]'), nsp(ffi.sizeof('char') * length)
  22. def value_uint64_pointer(val):
  23. return ffi.new('uint64_t*', val), ffi.sizeof('uint64_t')
  24. def value_int64_pointer(val):
  25. return ffi.new('int64_t*', val), ffi.sizeof('int64_t')
  26. def value_int_pointer(val):
  27. return ffi.new('int*', val), ffi.sizeof('int')
  28. def value_binary_data(val, length):
  29. return ffi.new(f'char[{length + 1:d}]', val), ffi.sizeof('char') * length
  30. _fd_size = ffi.sizeof('ZMQ_FD_T')
  31. ZMQ_FD_64BIT = _fd_size == 8
  32. IPC_PATH_MAX_LEN = C.get_ipc_path_max_len()
  33. def new_pointer_from_opt(option, length=0):
  34. opt_type = getattr(option, "_opt_type", _OptType.int)
  35. if opt_type == _OptType.int64 or (ZMQ_FD_64BIT and opt_type == _OptType.fd):
  36. return new_int64_pointer()
  37. elif opt_type == _OptType.bytes:
  38. return new_binary_data(length)
  39. else:
  40. # default
  41. return new_int_pointer()
  42. def value_from_opt_pointer(option, opt_pointer, length=0):
  43. try:
  44. option = SocketOption(option)
  45. except ValueError:
  46. # unrecognized option,
  47. # assume from the future,
  48. # let EINVAL raise
  49. opt_type = _OptType.int
  50. else:
  51. opt_type = option._opt_type
  52. if opt_type == _OptType.bytes:
  53. return ffi.buffer(opt_pointer, length)[:]
  54. else:
  55. return int(opt_pointer[0])
  56. def initialize_opt_pointer(option, value, length=0):
  57. opt_type = getattr(option, "_opt_type", _OptType.int)
  58. if opt_type == _OptType.int64 or (ZMQ_FD_64BIT and opt_type == _OptType.fd):
  59. return value_int64_pointer(value)
  60. elif opt_type == _OptType.bytes:
  61. return value_binary_data(value, length)
  62. else:
  63. return value_int_pointer(value)
  64. class Socket:
  65. context = None
  66. socket_type = None
  67. _zmq_socket = None
  68. _closed = None
  69. _ref = None
  70. _shadow = False
  71. _draft_poller = None
  72. _draft_poller_ptr = None
  73. copy_threshold = 0
  74. def __init__(self, context=None, socket_type=None, shadow=0, copy_threshold=None):
  75. if copy_threshold is None:
  76. copy_threshold = zmq.COPY_THRESHOLD
  77. self.copy_threshold = copy_threshold
  78. self.context = context
  79. self._draft_poller = self._draft_poller_ptr = None
  80. if shadow:
  81. self._zmq_socket = ffi.cast("void *", shadow)
  82. self._shadow = True
  83. else:
  84. self._shadow = False
  85. self._zmq_socket = C.zmq_socket(context._zmq_ctx, socket_type)
  86. if self._zmq_socket == ffi.NULL:
  87. raise ZMQError()
  88. self._closed = False
  89. @property
  90. def underlying(self):
  91. """The address of the underlying libzmq socket"""
  92. return int(ffi.cast('size_t', self._zmq_socket))
  93. def _check_closed_deep(self):
  94. """thorough check of whether the socket has been closed,
  95. even if by another entity (e.g. ctx.destroy).
  96. Only used by the `closed` property.
  97. returns True if closed, False otherwise
  98. """
  99. if self._closed:
  100. return True
  101. try:
  102. self.get(zmq.TYPE)
  103. except ZMQError as e:
  104. if e.errno == zmq.ENOTSOCK:
  105. self._closed = True
  106. return True
  107. elif e.errno == zmq.ETERM:
  108. pass
  109. else:
  110. raise
  111. return False
  112. @property
  113. def closed(self):
  114. return self._check_closed_deep()
  115. def close(self, linger=None):
  116. rc = 0
  117. if not self._closed and hasattr(self, '_zmq_socket'):
  118. if self._draft_poller_ptr is not None:
  119. rc = C.zmq_poller_destroy(self._draft_poller_ptr)
  120. self._draft_poller = self._draft_poller_ptr = None
  121. if self._zmq_socket is not None:
  122. if linger is not None:
  123. self.set(zmq.LINGER, linger)
  124. rc = C.zmq_close(self._zmq_socket)
  125. self._closed = True
  126. if rc < 0:
  127. _check_rc(rc)
  128. def bind(self, address):
  129. if isinstance(address, str):
  130. address_b = address.encode('utf8')
  131. else:
  132. address_b = address
  133. if isinstance(address, bytes):
  134. address = address_b.decode('utf8')
  135. rc = C.zmq_bind(self._zmq_socket, address_b)
  136. if rc < 0:
  137. if IPC_PATH_MAX_LEN and C.zmq_errno() == errno_mod.ENAMETOOLONG:
  138. path = address.split('://', 1)[-1]
  139. msg = (
  140. f'ipc path "{path}" is longer than {IPC_PATH_MAX_LEN} '
  141. 'characters (sizeof(sockaddr_un.sun_path)).'
  142. )
  143. raise ZMQError(C.zmq_errno(), msg=msg)
  144. elif C.zmq_errno() == errno_mod.ENOENT:
  145. path = address.split('://', 1)[-1]
  146. msg = f'No such file or directory for ipc path "{path}".'
  147. raise ZMQError(C.zmq_errno(), msg=msg)
  148. else:
  149. _check_rc(rc)
  150. def unbind(self, address):
  151. if isinstance(address, str):
  152. address = address.encode('utf8')
  153. rc = C.zmq_unbind(self._zmq_socket, address)
  154. _check_rc(rc)
  155. def connect(self, address):
  156. if isinstance(address, str):
  157. address = address.encode('utf8')
  158. rc = C.zmq_connect(self._zmq_socket, address)
  159. _check_rc(rc)
  160. def disconnect(self, address):
  161. if isinstance(address, str):
  162. address = address.encode('utf8')
  163. rc = C.zmq_disconnect(self._zmq_socket, address)
  164. _check_rc(rc)
  165. def set(self, option, value):
  166. length = None
  167. if isinstance(value, str):
  168. raise TypeError("unicode not allowed, use bytes")
  169. try:
  170. option = SocketOption(option)
  171. except ValueError:
  172. # unrecognized option,
  173. # assume from the future,
  174. # let EINVAL raise
  175. opt_type = _OptType.int
  176. else:
  177. opt_type = option._opt_type
  178. if isinstance(value, bytes):
  179. if opt_type != _OptType.bytes:
  180. raise TypeError(f"not a bytes sockopt: {option}")
  181. length = len(value)
  182. c_value_pointer, c_sizet = initialize_opt_pointer(option, value, length)
  183. _retry_sys_call(
  184. C.zmq_setsockopt,
  185. self._zmq_socket,
  186. option,
  187. ffi.cast('void*', c_value_pointer),
  188. c_sizet,
  189. )
  190. def get(self, option):
  191. try:
  192. option = SocketOption(option)
  193. except ValueError:
  194. # unrecognized option,
  195. # assume from the future,
  196. # let EINVAL raise
  197. opt_type = _OptType.int
  198. else:
  199. opt_type = option._opt_type
  200. if option == zmq.FD and self._draft_poller is not None:
  201. c_value_pointer, _ = new_pointer_from_opt(option)
  202. C.zmq_poller_fd(self._draft_poller, ffi.cast('void*', c_value_pointer))
  203. return int(c_value_pointer[0])
  204. c_value_pointer, c_sizet_pointer = new_pointer_from_opt(option, length=255)
  205. try:
  206. _retry_sys_call(
  207. C.zmq_getsockopt,
  208. self._zmq_socket,
  209. option,
  210. c_value_pointer,
  211. c_sizet_pointer,
  212. )
  213. except ZMQError as e:
  214. if (
  215. option == SocketOption.FD
  216. and e.errno == zmq.Errno.EINVAL
  217. and self.get(SocketOption.THREAD_SAFE)
  218. ):
  219. _check_version((4, 3, 2), "draft socket FD support via zmq_poller_fd")
  220. if not zmq.DRAFT_API:
  221. raise RuntimeError("libzmq must be built with draft support")
  222. warnings.warn(zmq.error.DraftFDWarning(), stacklevel=2)
  223. # create a poller and retrieve its fd
  224. self._draft_poller_ptr = ffi.new("void*[1]")
  225. self._draft_poller_ptr[0] = self._draft_poller = C.zmq_poller_new()
  226. if self._draft_poller == ffi.NULL:
  227. # failed (why?), raise original error
  228. self._draft_poller_ptr = self._draft_poller = None
  229. raise
  230. # register self with poller
  231. rc = C.zmq_poller_add(
  232. self._draft_poller,
  233. self._zmq_socket,
  234. ffi.NULL,
  235. zmq.POLLIN | zmq.POLLOUT,
  236. )
  237. _check_rc(rc)
  238. # use poller fd as proxy for ours
  239. rc = C.zmq_poller_fd(
  240. self._draft_poller, ffi.cast('void *', c_value_pointer)
  241. )
  242. _check_rc(rc)
  243. return int(c_value_pointer[0])
  244. else:
  245. raise
  246. sz = c_sizet_pointer[0]
  247. v = value_from_opt_pointer(option, c_value_pointer, sz)
  248. if (
  249. option != zmq.SocketOption.ROUTING_ID
  250. and opt_type == _OptType.bytes
  251. and v.endswith(b'\0')
  252. ):
  253. v = v[:-1]
  254. return v
  255. def _send_copy(self, buf, flags):
  256. """Send a copy of a bufferable"""
  257. zmq_msg = ffi.new('zmq_msg_t*')
  258. if not isinstance(buf, bytes):
  259. # cast any bufferable data to bytes via memoryview
  260. buf = memoryview(buf).tobytes()
  261. c_message = ffi.new('char[]', buf)
  262. rc = C.zmq_msg_init_size(zmq_msg, len(buf))
  263. _check_rc(rc)
  264. C.memcpy(C.zmq_msg_data(zmq_msg), c_message, len(buf))
  265. _retry_sys_call(C.zmq_msg_send, zmq_msg, self._zmq_socket, flags)
  266. rc2 = C.zmq_msg_close(zmq_msg)
  267. _check_rc(rc2)
  268. def _send_frame(self, frame, flags):
  269. """Send a Frame on this socket in a non-copy manner."""
  270. # Always copy the Frame so the original message isn't garbage collected.
  271. # This doesn't do a real copy, just a reference.
  272. frame_copy = frame.fast_copy()
  273. zmq_msg = frame_copy.zmq_msg
  274. _retry_sys_call(C.zmq_msg_send, zmq_msg, self._zmq_socket, flags)
  275. tracker = frame_copy.tracker
  276. frame_copy.close()
  277. return tracker
  278. def send(self, data, flags=0, copy=False, track=False):
  279. if isinstance(data, str):
  280. raise TypeError("Message must be in bytes, not a unicode object")
  281. if copy and not isinstance(data, Frame):
  282. return self._send_copy(data, flags)
  283. else:
  284. close_frame = False
  285. if isinstance(data, Frame):
  286. if track and not data.tracker:
  287. raise ValueError('Not a tracked message')
  288. frame = data
  289. else:
  290. if self.copy_threshold:
  291. buf = memoryview(data)
  292. # always copy messages smaller than copy_threshold
  293. if buf.nbytes < self.copy_threshold:
  294. self._send_copy(buf, flags)
  295. return zmq._FINISHED_TRACKER
  296. frame = Frame(data, track=track, copy_threshold=self.copy_threshold)
  297. close_frame = True
  298. tracker = self._send_frame(frame, flags)
  299. if close_frame:
  300. frame.close()
  301. return tracker
  302. def recv(self, flags=0, copy=True, track=False):
  303. if copy:
  304. zmq_msg = ffi.new('zmq_msg_t*')
  305. C.zmq_msg_init(zmq_msg)
  306. else:
  307. frame = zmq.Frame(track=track)
  308. zmq_msg = frame.zmq_msg
  309. try:
  310. _retry_sys_call(C.zmq_msg_recv, zmq_msg, self._zmq_socket, flags)
  311. except Exception:
  312. if copy:
  313. C.zmq_msg_close(zmq_msg)
  314. raise
  315. if not copy:
  316. return frame
  317. _buffer = ffi.buffer(C.zmq_msg_data(zmq_msg), C.zmq_msg_size(zmq_msg))
  318. _bytes = _buffer[:]
  319. rc = C.zmq_msg_close(zmq_msg)
  320. _check_rc(rc)
  321. return _bytes
  322. def recv_into(self, buffer, /, *, nbytes: int = 0, flags: int = 0) -> int:
  323. view = memoryview(buffer)
  324. if not view.contiguous:
  325. raise BufferError("Can only recv_into contiguous buffers")
  326. if view.readonly:
  327. raise BufferError("Cannot recv_into readonly buffer")
  328. if nbytes < 0:
  329. raise ValueError(f"{nbytes=} must be non-negative")
  330. view_bytes = view.nbytes
  331. if nbytes == 0:
  332. nbytes = view_bytes
  333. elif nbytes > view_bytes:
  334. raise ValueError(f"{nbytes=} too big for memoryview of {view_bytes}B")
  335. c_buf = ffi.from_buffer(view)
  336. rc: int = _retry_sys_call(C.zmq_recv, self._zmq_socket, c_buf, nbytes, flags)
  337. _check_rc(rc)
  338. return rc
  339. def monitor(self, addr, events=-1):
  340. """s.monitor(addr, flags)
  341. Start publishing socket events on inproc.
  342. See libzmq docs for zmq_monitor for details.
  343. Note: requires libzmq >= 3.2
  344. Parameters
  345. ----------
  346. addr : str
  347. The inproc url used for monitoring. Passing None as
  348. the addr will cause an existing socket monitor to be
  349. deregistered.
  350. events : int [default: zmq.EVENT_ALL]
  351. The zmq event bitmask for which events will be sent to the monitor.
  352. """
  353. if events < 0:
  354. events = zmq.EVENT_ALL
  355. if addr is None:
  356. addr = ffi.NULL
  357. if isinstance(addr, str):
  358. addr = addr.encode('utf8')
  359. C.zmq_socket_monitor(self._zmq_socket, addr, events)
  360. __all__ = ['Socket', 'IPC_PATH_MAX_LEN']