tracker.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. """Tracker for zero-copy messages with 0MQ."""
  2. # Copyright (C) PyZMQ Developers
  3. # Distributed under the terms of the Modified BSD License.
  4. from __future__ import annotations
  5. import time
  6. from threading import Event
  7. from zmq.backend import Frame
  8. from zmq.error import NotDone
  9. class MessageTracker:
  10. """A class for tracking if 0MQ is done using one or more messages.
  11. When you send a 0MQ message, it is not sent immediately. The 0MQ IO thread
  12. sends the message at some later time. Often you want to know when 0MQ has
  13. actually sent the message though. This is complicated by the fact that
  14. a single 0MQ message can be sent multiple times using different sockets.
  15. This class allows you to track all of the 0MQ usages of a message.
  16. Parameters
  17. ----------
  18. towatch : Event, MessageTracker, zmq.Frame
  19. This objects to track. This class can track the low-level
  20. Events used by the Message class, other MessageTrackers or
  21. actual Messages.
  22. """
  23. events: set[Event]
  24. peers: set[MessageTracker]
  25. def __init__(self, *towatch: tuple[MessageTracker | Event | Frame]):
  26. """Create a message tracker to track a set of messages.
  27. Parameters
  28. ----------
  29. *towatch : tuple of Event, MessageTracker, Message instances.
  30. This list of objects to track. This class can track the low-level
  31. Events used by the Message class, other MessageTrackers or
  32. actual Messages.
  33. """
  34. self.events = set()
  35. self.peers = set()
  36. for obj in towatch:
  37. if isinstance(obj, Event):
  38. self.events.add(obj)
  39. elif isinstance(obj, MessageTracker):
  40. self.peers.add(obj)
  41. elif isinstance(obj, Frame):
  42. if not obj.tracker:
  43. raise ValueError("Not a tracked message")
  44. self.peers.add(obj.tracker)
  45. else:
  46. raise TypeError(f"Require Events or Message Frames, not {type(obj)}")
  47. @property
  48. def done(self):
  49. """Is 0MQ completely done with the message(s) being tracked?"""
  50. for evt in self.events:
  51. if not evt.is_set():
  52. return False
  53. for pm in self.peers:
  54. if not pm.done:
  55. return False
  56. return True
  57. def wait(self, timeout: float | int = -1):
  58. """Wait for 0MQ to be done with the message or until `timeout`.
  59. Parameters
  60. ----------
  61. timeout : float
  62. default: -1, which means wait forever.
  63. Maximum time in (s) to wait before raising NotDone.
  64. Returns
  65. -------
  66. None
  67. if done before `timeout`
  68. Raises
  69. ------
  70. NotDone
  71. if `timeout` reached before I am done.
  72. """
  73. tic = time.time()
  74. remaining: float
  75. if timeout is False or timeout < 0:
  76. remaining = 3600 * 24 * 7 # a week
  77. else:
  78. remaining = timeout
  79. for evt in self.events:
  80. if remaining < 0:
  81. raise NotDone
  82. evt.wait(timeout=remaining)
  83. if not evt.is_set():
  84. raise NotDone
  85. toc = time.time()
  86. remaining -= toc - tic
  87. tic = toc
  88. for peer in self.peers:
  89. if remaining < 0:
  90. raise NotDone
  91. peer.wait(timeout=remaining)
  92. toc = time.time()
  93. remaining -= toc - tic
  94. tic = toc
  95. _FINISHED_TRACKER = MessageTracker()
  96. __all__ = ['MessageTracker', '_FINISHED_TRACKER']