thread.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. """ZAP Authenticator in a Python Thread.
  2. .. versionadded:: 14.1
  3. """
  4. # Copyright (C) PyZMQ Developers
  5. # Distributed under the terms of the Modified BSD License.
  6. import asyncio
  7. from threading import Event, Thread
  8. from typing import Any, List, Optional
  9. import zmq
  10. import zmq.asyncio
  11. from .base import Authenticator
  12. class AuthenticationThread(Thread):
  13. """A Thread for running a zmq Authenticator
  14. This is run in the background by ThreadAuthenticator
  15. """
  16. pipe: zmq.Socket
  17. loop: asyncio.AbstractEventLoop
  18. authenticator: Authenticator
  19. poller: Optional[zmq.asyncio.Poller] = None
  20. def __init__(
  21. self,
  22. authenticator: Authenticator,
  23. pipe: zmq.Socket,
  24. ) -> None:
  25. super().__init__(daemon=True)
  26. self.authenticator = authenticator
  27. self.log = authenticator.log
  28. self.pipe = pipe
  29. self.started = Event()
  30. def run(self) -> None:
  31. """Start the Authentication Agent thread task"""
  32. loop = asyncio.new_event_loop()
  33. try:
  34. loop.run_until_complete(self._run())
  35. finally:
  36. if self.pipe:
  37. self.pipe.close()
  38. self.pipe = None # type: ignore
  39. loop.close()
  40. async def _run(self):
  41. self.poller = zmq.asyncio.Poller()
  42. self.poller.register(self.pipe, zmq.POLLIN)
  43. self.poller.register(self.authenticator.zap_socket, zmq.POLLIN)
  44. self.started.set()
  45. while True:
  46. events = dict(await self.poller.poll())
  47. if self.pipe in events:
  48. msg = self.pipe.recv_multipart()
  49. if self._handle_pipe_message(msg):
  50. return
  51. if self.authenticator.zap_socket in events:
  52. msg = self.authenticator.zap_socket.recv_multipart()
  53. await self.authenticator.handle_zap_message(msg)
  54. def _handle_pipe_message(self, msg: List[bytes]) -> bool:
  55. command = msg[0]
  56. self.log.debug("auth received API command %r", command)
  57. if command == b'TERMINATE':
  58. return True
  59. else:
  60. self.log.error("Invalid auth command from API: %r", command)
  61. self.pipe.send(b'ERROR')
  62. return False
  63. class ThreadAuthenticator(Authenticator):
  64. """Run ZAP authentication in a background thread"""
  65. pipe: "zmq.Socket"
  66. pipe_endpoint: str = ''
  67. thread: AuthenticationThread
  68. def __init__(
  69. self,
  70. context: Optional["zmq.Context"] = None,
  71. encoding: str = 'utf-8',
  72. log: Any = None,
  73. ):
  74. super().__init__(context=context, encoding=encoding, log=log)
  75. self.pipe = None # type: ignore
  76. self.pipe_endpoint = f"inproc://{id(self)}.inproc"
  77. self.thread = None # type: ignore
  78. def start(self) -> None:
  79. """Start the authentication thread"""
  80. # start the Authenticator
  81. super().start()
  82. # create a socket pair to communicate with auth thread.
  83. self.pipe = self.context.socket(zmq.PAIR, socket_class=zmq.Socket)
  84. self.pipe.linger = 1
  85. self.pipe.bind(self.pipe_endpoint)
  86. thread_pipe = self.context.socket(zmq.PAIR, socket_class=zmq.Socket)
  87. thread_pipe.linger = 1
  88. thread_pipe.connect(self.pipe_endpoint)
  89. self.thread = AuthenticationThread(authenticator=self, pipe=thread_pipe)
  90. self.thread.start()
  91. if not self.thread.started.wait(timeout=10):
  92. raise RuntimeError("Authenticator thread failed to start")
  93. def stop(self) -> None:
  94. """Stop the authentication thread"""
  95. if self.pipe:
  96. self.pipe.send(b'TERMINATE')
  97. if self.is_alive():
  98. self.thread.join()
  99. self.thread = None # type: ignore
  100. self.pipe.close()
  101. self.pipe = None # type: ignore
  102. super().stop()
  103. def is_alive(self) -> bool:
  104. """Is the ZAP thread currently running?"""
  105. return bool(self.thread and self.thread.is_alive())
  106. def __del__(self) -> None:
  107. self.stop()
  108. __all__ = ['ThreadAuthenticator']