| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435 |
- """zmq Socket class"""
- # Copyright (C) PyZMQ Developers
- # Distributed under the terms of the Modified BSD License.
- import errno as errno_mod
- import warnings
- import zmq
- from zmq.constants import SocketOption, _OptType
- from zmq.error import ZMQError, _check_rc, _check_version
- from ._cffi import ffi
- from ._cffi import lib as C
- from .message import Frame
- from .utils import _retry_sys_call
def new_sizet_pointer(length):
    """Allocate a cffi ``size_t*`` whose pointee is initialised to *length*."""
    return ffi.new('size_t*', length)


# Short alias used by the pointer helpers in this module.
nsp = new_sizet_pointer
def new_uint64_pointer():
    """Allocate an output slot for a ``uint64_t`` value.

    Returns a ``(value_pointer, size_pointer)`` pair: a fresh ``uint64_t*``
    and a ``size_t*`` holding ``sizeof(uint64_t)``.
    """
    value_ptr = ffi.new('uint64_t*')
    size_ptr = nsp(ffi.sizeof('uint64_t'))
    return value_ptr, size_ptr
def new_int64_pointer():
    """Allocate an output slot for an ``int64_t`` value.

    Returns a ``(value_pointer, size_pointer)`` pair: a fresh ``int64_t*``
    and a ``size_t*`` holding ``sizeof(int64_t)``.
    """
    value_ptr = ffi.new('int64_t*')
    size_ptr = nsp(ffi.sizeof('int64_t'))
    return value_ptr, size_ptr
def new_int_pointer():
    """Allocate an output slot for a C ``int`` value.

    Returns a ``(value_pointer, size_pointer)`` pair: a fresh ``int*``
    and a ``size_t*`` holding ``sizeof(int)``.
    """
    value_ptr = ffi.new('int*')
    size_ptr = nsp(ffi.sizeof('int'))
    return value_ptr, size_ptr
def new_binary_data(length):
    """Allocate an output buffer of *length* chars.

    Returns a ``(buffer, size_pointer)`` pair: a fresh ``char[length]``
    array and a ``size_t*`` holding ``sizeof(char) * length``.
    """
    buffer = ffi.new(f'char[{length:d}]')
    size_ptr = nsp(ffi.sizeof('char') * length)
    return buffer, size_ptr
def value_uint64_pointer(val):
    """Wrap *val* in a ``uint64_t*``.

    Returns ``(pointer, size)`` where *size* is the plain integer
    ``sizeof(uint64_t)`` (not a pointer).
    """
    pointer = ffi.new('uint64_t*', val)
    size = ffi.sizeof('uint64_t')
    return pointer, size
def value_int64_pointer(val):
    """Wrap *val* in an ``int64_t*``.

    Returns ``(pointer, size)`` where *size* is the plain integer
    ``sizeof(int64_t)`` (not a pointer).
    """
    pointer = ffi.new('int64_t*', val)
    size = ffi.sizeof('int64_t')
    return pointer, size
def value_int_pointer(val):
    """Wrap *val* in a C ``int*``.

    Returns ``(pointer, size)`` where *size* is the plain integer
    ``sizeof(int)`` (not a pointer).
    """
    pointer = ffi.new('int*', val)
    size = ffi.sizeof('int')
    return pointer, size
def value_binary_data(val, length):
    """Copy *val* into a freshly allocated char array.

    The array is allocated with ``length + 1`` elements (presumably room
    for a trailing NUL -- confirm against cffi's initialiser rules), but
    the reported size is ``sizeof(char) * length``.

    Returns ``(buffer, size)`` where *size* is a plain integer.
    """
    buffer = ffi.new(f'char[{length + 1:d}]', val)
    size = ffi.sizeof('char') * length
    return buffer, size
# Size in bytes of the C type ``ZMQ_FD_T`` as seen by cffi.
_fd_size = ffi.sizeof('ZMQ_FD_T')
# True when ZMQ_FD_T is 8 bytes wide; FD options then use int64 storage.
ZMQ_FD_64BIT = (_fd_size == 8)
# Value reported by the C helper ``get_ipc_path_max_len``; used by
# ``Socket.bind`` to build a clearer ENAMETOOLONG error message.
IPC_PATH_MAX_LEN = C.get_ipc_path_max_len()
def new_pointer_from_opt(option, length=0):
    """Allocate output storage suited to reading *option*.

    The storage type is chosen from ``option._opt_type``; objects without
    that attribute (e.g. a raw integer) are treated as plain ``int``.

    Parameters
    ----------
    option
        Socket option being read.
    length : int
        Buffer length, only used for bytes-typed options.

    Returns
    -------
    tuple
        ``(value_pointer, size_pointer)`` from the matching ``new_*`` helper.
    """
    kind = getattr(option, "_opt_type", _OptType.int)
    # FD options need 64-bit storage when ZMQ_FD_T is 8 bytes wide
    needs_int64 = kind == _OptType.int64 or (ZMQ_FD_64BIT and kind == _OptType.fd)
    if needs_int64:
        return new_int64_pointer()
    if kind == _OptType.bytes:
        return new_binary_data(length)
    # everything else is read as a C int
    return new_int_pointer()
def value_from_opt_pointer(option, opt_pointer, length=0):
    """Convert filled-in option storage back to a Python value.

    Parameters
    ----------
    option
        Socket option that was read; values not recognised by
        ``SocketOption`` are treated as ``int`` options.
    opt_pointer
        cffi pointer/array holding the value.
    length : int
        Number of bytes to copy out, only used for bytes-typed options.

    Returns
    -------
    bytes or int
        A copy of the first *length* bytes for bytes options, otherwise
        ``int(opt_pointer[0])``.
    """
    try:
        known = SocketOption(option)
    except ValueError:
        # unrecognized option,
        # assume from the future,
        # let EINVAL raise
        known = None

    kind = _OptType.int if known is None else known._opt_type

    if kind != _OptType.bytes:
        return int(opt_pointer[0])
    return ffi.buffer(opt_pointer, length)[:]
def initialize_opt_pointer(option, value, length=0):
    """Build C storage holding *value* for writing *option*.

    The storage type is chosen from ``option._opt_type``; objects without
    that attribute (e.g. a raw integer) are treated as plain ``int``.

    Parameters
    ----------
    option
        Socket option being written.
    value
        Value to store.
    length : int
        Byte length of *value*, only used for bytes-typed options.

    Returns
    -------
    tuple
        ``(pointer, size)`` from the matching ``value_*`` helper.
    """
    kind = getattr(option, "_opt_type", _OptType.int)
    # FD options need 64-bit storage when ZMQ_FD_T is 8 bytes wide
    needs_int64 = kind == _OptType.int64 or (ZMQ_FD_64BIT and kind == _OptType.fd)
    if needs_int64:
        return value_int64_pointer(value)
    if kind == _OptType.bytes:
        return value_binary_data(value, length)
    return value_int_pointer(value)
class Socket:
    """Low-level libzmq socket backed by cffi.

    Holds a ``void*`` libzmq socket handle that is either created from a
    context via ``zmq_socket`` or adopted from an existing address
    (``shadow``).
    """

    # Class-level defaults; ``__init__`` overwrites most of them per instance.
    context = None
    socket_type = None
    _zmq_socket = None
    _closed = None
    _ref = None
    _shadow = False
    # zmq_poller handle (and the ``void*[1]`` cell owning it) that ``get``
    # creates lazily when FD cannot be read with zmq_getsockopt.
    _draft_poller = None
    _draft_poller_ptr = None
    copy_threshold = 0

    def __init__(self, context=None, socket_type=None, shadow=0, copy_threshold=None):
        """Create a new libzmq socket or shadow an existing one.

        Parameters
        ----------
        context
            Object exposing ``_zmq_ctx``; used when not shadowing.
        socket_type : int
            Socket type passed to ``zmq_socket`` when not shadowing.
        shadow : int
            If truthy, the address of an existing libzmq socket to wrap
            instead of creating a new one.
        copy_threshold : int, optional
            Size below which ``send`` always copies; defaults to
            ``zmq.COPY_THRESHOLD`` when None.

        Raises
        ------
        ZMQError
            If ``zmq_socket`` returns NULL.
        """
        if copy_threshold is None:
            copy_threshold = zmq.COPY_THRESHOLD
        self.copy_threshold = copy_threshold

        self.context = context
        self._draft_poller = self._draft_poller_ptr = None
        if shadow:
            self._zmq_socket = ffi.cast("void *", shadow)
            self._shadow = True
        else:
            self._shadow = False
            self._zmq_socket = C.zmq_socket(context._zmq_ctx, socket_type)
        if self._zmq_socket == ffi.NULL:
            raise ZMQError()
        self._closed = False

    @property
    def underlying(self):
        """The address of the underlying libzmq socket"""
        return int(ffi.cast('size_t', self._zmq_socket))

    def _check_closed_deep(self):
        """thorough check of whether the socket has been closed,
        even if by another entity (e.g. ctx.destroy).

        Only used by the `closed` property.

        returns True if closed, False otherwise
        """
        if self._closed:
            return True
        try:
            # probe the handle; ENOTSOCK means it was closed behind our back
            self.get(zmq.TYPE)
        except ZMQError as e:
            if e.errno == zmq.ENOTSOCK:
                self._closed = True
                return True
            elif e.errno == zmq.ETERM:
                # ETERM is not treated as closed; fall through to False
                pass
            else:
                raise
        return False

    @property
    def closed(self):
        """Whether the socket is closed, per ``_check_closed_deep``."""
        return self._check_closed_deep()

    def close(self, linger=None):
        """Close the socket, destroying the draft poller if one was created.

        Does nothing if the socket is already marked closed.

        Parameters
        ----------
        linger : int, optional
            If not None, set as ``zmq.LINGER`` immediately before closing.

        Raises
        ------
        ZMQError
            If the last libzmq call made here returned a negative code.
        """
        rc = 0
        if not self._closed and hasattr(self, '_zmq_socket'):
            if self._draft_poller_ptr is not None:
                rc = C.zmq_poller_destroy(self._draft_poller_ptr)
                self._draft_poller = self._draft_poller_ptr = None

            if self._zmq_socket is not None:
                if linger is not None:
                    self.set(zmq.LINGER, linger)
                rc = C.zmq_close(self._zmq_socket)
            self._closed = True
        if rc < 0:
            _check_rc(rc)

    def bind(self, address):
        """Bind the socket to *address*.

        Parameters
        ----------
        address : str or bytes
            Endpoint to bind; str is encoded as UTF-8 for libzmq, bytes
            is decoded as UTF-8 for use in error messages.

        Raises
        ------
        ZMQError
            On failure. ENAMETOOLONG (when ``IPC_PATH_MAX_LEN`` is truthy)
            and ENOENT get a message naming the offending ipc path.
        """
        if isinstance(address, str):
            address_b = address.encode('utf8')
        else:
            address_b = address
        if isinstance(address, bytes):
            address = address_b.decode('utf8')
        rc = C.zmq_bind(self._zmq_socket, address_b)
        if rc < 0:
            # read errno once, right after the failing call
            bind_errno = C.zmq_errno()
            if IPC_PATH_MAX_LEN and bind_errno == errno_mod.ENAMETOOLONG:
                path = address.split('://', 1)[-1]
                msg = (
                    f'ipc path "{path}" is longer than {IPC_PATH_MAX_LEN} '
                    'characters (sizeof(sockaddr_un.sun_path)).'
                )
                raise ZMQError(bind_errno, msg=msg)
            elif bind_errno == errno_mod.ENOENT:
                path = address.split('://', 1)[-1]
                msg = f'No such file or directory for ipc path "{path}".'
                raise ZMQError(bind_errno, msg=msg)
            else:
                _check_rc(rc)

    def unbind(self, address):
        """Unbind from *address* (str is encoded as UTF-8).

        Raises ZMQError (via ``_check_rc``) on failure.
        """
        if isinstance(address, str):
            address = address.encode('utf8')
        rc = C.zmq_unbind(self._zmq_socket, address)
        _check_rc(rc)

    def connect(self, address):
        """Connect to *address* (str is encoded as UTF-8).

        Raises ZMQError (via ``_check_rc``) on failure.
        """
        if isinstance(address, str):
            address = address.encode('utf8')
        rc = C.zmq_connect(self._zmq_socket, address)
        _check_rc(rc)

    def disconnect(self, address):
        """Disconnect from *address* (str is encoded as UTF-8).

        Raises ZMQError (via ``_check_rc``) on failure.
        """
        if isinstance(address, str):
            address = address.encode('utf8')
        rc = C.zmq_disconnect(self._zmq_socket, address)
        _check_rc(rc)

    def set(self, option, value):
        """Set a socket option.

        Parameters
        ----------
        option : int or SocketOption
            Option to set. Values unknown to ``SocketOption`` are treated
            as int options and passed through to libzmq.
        value : int or bytes
            New value; bytes only for bytes-typed options.

        Raises
        ------
        TypeError
            If *value* is str, if bytes is given for a non-bytes option,
            or if a non-bytes value is given for a bytes option.
        """
        length = None
        if isinstance(value, str):
            raise TypeError("unicode not allowed, use bytes")

        try:
            option = SocketOption(option)
        except ValueError:
            # unrecognized option,
            # assume from the future,
            # let EINVAL raise
            opt_type = _OptType.int
        else:
            opt_type = option._opt_type

        if isinstance(value, bytes):
            if opt_type != _OptType.bytes:
                raise TypeError(f"not a bytes sockopt: {option}")
            length = len(value)
        elif opt_type == _OptType.bytes:
            # Without a byte length the buffer cannot be sized; fail here
            # with a clear message rather than an arithmetic TypeError.
            raise TypeError(
                f"bytes sockopt {option} requires bytes, "
                f"not {type(value).__name__}"
            )

        c_value_pointer, c_sizet = initialize_opt_pointer(option, value, length)

        _retry_sys_call(
            C.zmq_setsockopt,
            self._zmq_socket,
            option,
            ffi.cast('void*', c_value_pointer),
            c_sizet,
        )

    def get(self, option):
        """Get a socket option.

        For ``FD`` on a socket where ``zmq_getsockopt`` fails with EINVAL
        and ``THREAD_SAFE`` is set, a zmq_poller is created, the socket is
        registered with it, and the poller's fd is returned instead. Later
        ``FD`` lookups reuse that poller.

        Parameters
        ----------
        option : int or SocketOption
            Option to read. Values unknown to ``SocketOption`` are treated
            as int options.

        Returns
        -------
        int or bytes
            The option value. For bytes options other than ``ROUTING_ID``
            a single trailing NUL byte is stripped.

        Raises
        ------
        ZMQError
            If the underlying libzmq call fails.
        RuntimeError
            If the poller fallback is needed but ``zmq.DRAFT_API`` is false.
        """
        try:
            option = SocketOption(option)
        except ValueError:
            # unrecognized option,
            # assume from the future,
            # let EINVAL raise
            opt_type = _OptType.int
        else:
            opt_type = option._opt_type

        if option == zmq.FD and self._draft_poller is not None:
            # a poller already stands in for this socket's FD
            c_value_pointer, _ = new_pointer_from_opt(option)
            rc = C.zmq_poller_fd(
                self._draft_poller, ffi.cast('void*', c_value_pointer)
            )
            # checked here as it is for the same call further down
            _check_rc(rc)
            return int(c_value_pointer[0])

        # bytes options are read into a fixed 255-byte buffer
        c_value_pointer, c_sizet_pointer = new_pointer_from_opt(option, length=255)

        try:
            _retry_sys_call(
                C.zmq_getsockopt,
                self._zmq_socket,
                option,
                c_value_pointer,
                c_sizet_pointer,
            )
        except ZMQError as e:
            if (
                option == SocketOption.FD
                and e.errno == zmq.Errno.EINVAL
                and self.get(SocketOption.THREAD_SAFE)
            ):
                _check_version((4, 3, 2), "draft socket FD support via zmq_poller_fd")
                if not zmq.DRAFT_API:
                    raise RuntimeError("libzmq must be built with draft support")
                warnings.warn(zmq.error.DraftFDWarning(), stacklevel=2)

                # create a poller and retrieve its fd
                self._draft_poller_ptr = ffi.new("void*[1]")
                self._draft_poller_ptr[0] = self._draft_poller = C.zmq_poller_new()
                if self._draft_poller == ffi.NULL:
                    # failed (why?), raise original error
                    self._draft_poller_ptr = self._draft_poller = None
                    raise
                # register self with poller
                rc = C.zmq_poller_add(
                    self._draft_poller,
                    self._zmq_socket,
                    ffi.NULL,
                    zmq.POLLIN | zmq.POLLOUT,
                )
                _check_rc(rc)
                # use poller fd as proxy for ours
                rc = C.zmq_poller_fd(
                    self._draft_poller, ffi.cast('void *', c_value_pointer)
                )
                _check_rc(rc)
                return int(c_value_pointer[0])
            else:
                raise

        sz = c_sizet_pointer[0]
        v = value_from_opt_pointer(option, c_value_pointer, sz)
        if (
            option != zmq.SocketOption.ROUTING_ID
            and opt_type == _OptType.bytes
            and v.endswith(b'\0')
        ):
            v = v[:-1]
        return v

    def _send_copy(self, buf, flags):
        """Send a copy of a bufferable

        The data is copied into a new ``zmq_msg_t`` which is then sent.
        If copying or sending raises, the message is closed before the
        exception propagates so it is not leaked.
        """
        zmq_msg = ffi.new('zmq_msg_t*')
        if not isinstance(buf, bytes):
            # cast any bufferable data to bytes via memoryview
            buf = memoryview(buf).tobytes()

        c_message = ffi.new('char[]', buf)
        rc = C.zmq_msg_init_size(zmq_msg, len(buf))
        _check_rc(rc)
        try:
            C.memcpy(C.zmq_msg_data(zmq_msg), c_message, len(buf))
            _retry_sys_call(C.zmq_msg_send, zmq_msg, self._zmq_socket, flags)
        except Exception:
            # same cleanup as recv(): release the message, then re-raise
            C.zmq_msg_close(zmq_msg)
            raise
        rc2 = C.zmq_msg_close(zmq_msg)
        _check_rc(rc2)

    def _send_frame(self, frame, flags):
        """Send a Frame on this socket in a non-copy manner.

        Returns the tracker of the frame copy that was sent.
        """
        # Always copy the Frame so the original message isn't garbage collected.
        # This doesn't do a real copy, just a reference.
        frame_copy = frame.fast_copy()
        zmq_msg = frame_copy.zmq_msg
        _retry_sys_call(C.zmq_msg_send, zmq_msg, self._zmq_socket, flags)
        tracker = frame_copy.tracker
        frame_copy.close()
        return tracker

    def send(self, data, flags=0, copy=False, track=False):
        """Send a single message part.

        Parameters
        ----------
        data : bytes, Frame, or buffer-like
            Payload to send; str is rejected.
        flags : int
            Flags passed to ``zmq_msg_send``.
        copy : bool
            If true and *data* is not a Frame, send a copy of the data.
        track : bool
            Passed to the Frame created for non-copying sends; a Frame
            passed in by the caller must already have a tracker.

        Returns
        -------
        object or None
            None for ``copy=True`` sends, ``zmq._FINISHED_TRACKER`` when
            the payload is below ``copy_threshold`` and was copied,
            otherwise the tracker returned by ``_send_frame``.

        Raises
        ------
        TypeError
            If *data* is str.
        ValueError
            If ``track`` is true and the given Frame has no tracker.
        """
        if isinstance(data, str):
            raise TypeError("Message must be in bytes, not a unicode object")

        if copy and not isinstance(data, Frame):
            return self._send_copy(data, flags)
        else:
            close_frame = False
            if isinstance(data, Frame):
                if track and not data.tracker:
                    raise ValueError('Not a tracked message')
                frame = data
            else:
                if self.copy_threshold:
                    buf = memoryview(data)
                    # always copy messages smaller than copy_threshold
                    if buf.nbytes < self.copy_threshold:
                        self._send_copy(buf, flags)
                        return zmq._FINISHED_TRACKER
                frame = Frame(data, track=track, copy_threshold=self.copy_threshold)
                # only close Frames created here, never the caller's
                close_frame = True

            tracker = self._send_frame(frame, flags)
            if close_frame:
                frame.close()
            return tracker

    def recv(self, flags=0, copy=True, track=False):
        """Receive a single message part.

        Parameters
        ----------
        flags : int
            Flags passed to ``zmq_msg_recv``.
        copy : bool
            If true, return the payload as bytes; otherwise return the
            ``zmq.Frame`` the message was received into.
        track : bool
            Passed to ``zmq.Frame`` when ``copy`` is false.

        Returns
        -------
        bytes or zmq.Frame
        """
        if copy:
            zmq_msg = ffi.new('zmq_msg_t*')
            C.zmq_msg_init(zmq_msg)
        else:
            frame = zmq.Frame(track=track)
            zmq_msg = frame.zmq_msg

        try:
            _retry_sys_call(C.zmq_msg_recv, zmq_msg, self._zmq_socket, flags)
        except Exception:
            if copy:
                # the message was initialised here, so release it here
                C.zmq_msg_close(zmq_msg)
            raise

        if not copy:
            return frame

        _buffer = ffi.buffer(C.zmq_msg_data(zmq_msg), C.zmq_msg_size(zmq_msg))
        # slicing copies the data out before the message is closed
        _bytes = _buffer[:]
        rc = C.zmq_msg_close(zmq_msg)
        _check_rc(rc)
        return _bytes

    def recv_into(self, buffer, /, *, nbytes: int = 0, flags: int = 0) -> int:
        """Receive a message part directly into a writable buffer.

        Parameters
        ----------
        buffer
            Contiguous, writable object supporting the buffer protocol.
        nbytes : int
            Maximum number of bytes to receive; 0 means the whole buffer.
        flags : int
            Flags passed to ``zmq_recv``.

        Returns
        -------
        int
            The value returned by ``zmq_recv``.

        Raises
        ------
        BufferError
            If the buffer is non-contiguous or read-only.
        ValueError
            If ``nbytes`` is negative or larger than the buffer.
        """
        view = memoryview(buffer)
        if not view.contiguous:
            raise BufferError("Can only recv_into contiguous buffers")
        if view.readonly:
            raise BufferError("Cannot recv_into readonly buffer")
        if nbytes < 0:
            raise ValueError(f"{nbytes=} must be non-negative")
        view_bytes = view.nbytes
        if nbytes == 0:
            nbytes = view_bytes
        elif nbytes > view_bytes:
            raise ValueError(f"{nbytes=} too big for memoryview of {view_bytes}B")
        c_buf = ffi.from_buffer(view)
        rc: int = _retry_sys_call(C.zmq_recv, self._zmq_socket, c_buf, nbytes, flags)
        _check_rc(rc)
        return rc

    def monitor(self, addr, events=-1):
        """s.monitor(addr, events)

        Start publishing socket events on inproc.
        See libzmq docs for zmq_monitor for details.

        Note: requires libzmq >= 3.2

        Parameters
        ----------
        addr : str
            The inproc url used for monitoring. Passing None as
            the addr will cause an existing socket monitor to be
            deregistered.
        events : int [default: zmq.EVENT_ALL]
            The zmq event bitmask for which events will be sent to the monitor.

        Raises
        ------
        ZMQError
            If ``zmq_socket_monitor`` reports failure.
        """
        if events < 0:
            events = zmq.EVENT_ALL
        if addr is None:
            addr = ffi.NULL
        if isinstance(addr, str):
            addr = addr.encode('utf8')
        rc = C.zmq_socket_monitor(self._zmq_socket, addr, events)
        # report failures instead of silently ignoring them
        _check_rc(rc)
# Public names exported by this module.
__all__ = [
    'Socket',
    'IPC_PATH_MAX_LEN',
]
|