worker.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. from abc import ABC, abstractmethod
  2. import asyncio
  3. import os
  4. import threading
  5. from time import sleep, time
  6. from sentry_sdk._queue import Queue, FullError
  7. from sentry_sdk.utils import logger, mark_sentry_task_internal
  8. from sentry_sdk.consts import DEFAULT_QUEUE_SIZE
  9. from typing import TYPE_CHECKING
  10. if TYPE_CHECKING:
  11. from typing import Any
  12. from typing import Optional
  13. from typing import Callable
# Unique sentinel placed on a worker queue; a consumer loop that dequeues it
# stops processing and exits (compared by identity with ``is``).
_TERMINATOR = object()
  15. class Worker(ABC):
  16. """Base class for all workers."""
  17. @property
  18. @abstractmethod
  19. def is_alive(self) -> bool:
  20. """Whether the worker is alive and running."""
  21. pass
  22. @abstractmethod
  23. def kill(self) -> None:
  24. """Kill the worker. It will not process any more events."""
  25. pass
  26. def flush(
  27. self, timeout: float, callback: "Optional[Callable[[int, float], Any]]" = None
  28. ) -> None:
  29. """Flush the worker, blocking until done or timeout is reached."""
  30. return None
  31. @abstractmethod
  32. def full(self) -> bool:
  33. """Whether the worker's queue is full."""
  34. pass
  35. @abstractmethod
  36. def submit(self, callback: "Callable[[], Any]") -> bool:
  37. """Schedule a callback. Returns True if queued, False if full."""
  38. pass
  39. class BackgroundWorker(Worker):
  40. def __init__(self, queue_size: int = DEFAULT_QUEUE_SIZE) -> None:
  41. self._queue: "Queue" = Queue(queue_size)
  42. self._lock = threading.Lock()
  43. self._thread: "Optional[threading.Thread]" = None
  44. self._thread_for_pid: "Optional[int]" = None
  45. @property
  46. def is_alive(self) -> bool:
  47. if self._thread_for_pid != os.getpid():
  48. return False
  49. if not self._thread:
  50. return False
  51. return self._thread.is_alive()
  52. def _ensure_thread(self) -> None:
  53. if not self.is_alive:
  54. self.start()
  55. def _timed_queue_join(self, timeout: float) -> bool:
  56. deadline = time() + timeout
  57. queue = self._queue
  58. queue.all_tasks_done.acquire()
  59. try:
  60. while queue.unfinished_tasks:
  61. delay = deadline - time()
  62. if delay <= 0:
  63. return False
  64. queue.all_tasks_done.wait(timeout=delay)
  65. return True
  66. finally:
  67. queue.all_tasks_done.release()
  68. def start(self) -> None:
  69. with self._lock:
  70. if not self.is_alive:
  71. self._thread = threading.Thread(
  72. target=self._target, name="sentry-sdk.BackgroundWorker"
  73. )
  74. self._thread.daemon = True
  75. try:
  76. self._thread.start()
  77. self._thread_for_pid = os.getpid()
  78. except RuntimeError:
  79. # At this point we can no longer start because the interpreter
  80. # is already shutting down. Sadly at this point we can no longer
  81. # send out events.
  82. self._thread = None
  83. def kill(self) -> None:
  84. """
  85. Kill worker thread. Returns immediately. Not useful for
  86. waiting on shutdown for events, use `flush` for that.
  87. """
  88. logger.debug("background worker got kill request")
  89. with self._lock:
  90. if self._thread:
  91. try:
  92. self._queue.put_nowait(_TERMINATOR)
  93. except FullError:
  94. logger.debug("background worker queue full, kill failed")
  95. self._thread = None
  96. self._thread_for_pid = None
  97. def flush(self, timeout: float, callback: "Optional[Any]" = None) -> None:
  98. logger.debug("background worker got flush request")
  99. with self._lock:
  100. if self.is_alive and timeout > 0.0:
  101. self._wait_flush(timeout, callback)
  102. logger.debug("background worker flushed")
  103. def full(self) -> bool:
  104. return self._queue.full()
  105. def _wait_flush(self, timeout: float, callback: "Optional[Any]") -> None:
  106. initial_timeout = min(0.1, timeout)
  107. if not self._timed_queue_join(initial_timeout):
  108. pending = self._queue.qsize() + 1
  109. logger.debug("%d event(s) pending on flush", pending)
  110. if callback is not None:
  111. callback(pending, timeout)
  112. if not self._timed_queue_join(timeout - initial_timeout):
  113. pending = self._queue.qsize() + 1
  114. logger.error("flush timed out, dropped %s events", pending)
  115. def submit(self, callback: "Callable[[], Any]") -> bool:
  116. self._ensure_thread()
  117. try:
  118. self._queue.put_nowait(callback)
  119. return True
  120. except FullError:
  121. return False
  122. def _target(self) -> None:
  123. while True:
  124. callback = self._queue.get()
  125. try:
  126. if callback is _TERMINATOR:
  127. break
  128. try:
  129. callback()
  130. except Exception:
  131. logger.error("Failed processing job", exc_info=True)
  132. finally:
  133. self._queue.task_done()
  134. sleep(0)
  135. class AsyncWorker(Worker):
  136. def __init__(self, queue_size: int = DEFAULT_QUEUE_SIZE) -> None:
  137. self._queue: "Optional[asyncio.Queue[Any]]" = None
  138. self._queue_size = queue_size
  139. self._task: "Optional[asyncio.Task[None]]" = None
  140. # Event loop needs to remain in the same process
  141. self._task_for_pid: "Optional[int]" = None
  142. self._loop: "Optional[asyncio.AbstractEventLoop]" = None
  143. # Track active callback tasks so they have a strong reference and can be cancelled on kill
  144. self._active_tasks: "set[asyncio.Task[None]]" = set()
  145. @property
  146. def is_alive(self) -> bool:
  147. if self._task_for_pid != os.getpid():
  148. return False
  149. if not self._task or not self._loop:
  150. return False
  151. return self._loop.is_running() and not self._task.done()
  152. def kill(self) -> None:
  153. if self._task:
  154. # Cancel the main consumer task to prevent duplicate consumers
  155. self._task.cancel()
  156. # Also cancel any active callback tasks
  157. # Avoid modifying the set while cancelling tasks
  158. tasks_to_cancel = set(self._active_tasks)
  159. for task in tasks_to_cancel:
  160. task.cancel()
  161. self._active_tasks.clear()
  162. self._loop = None
  163. self._task = None
  164. self._task_for_pid = None
  165. def start(self) -> None:
  166. if not self.is_alive:
  167. try:
  168. self._loop = asyncio.get_running_loop()
  169. # Always create a fresh queue on start to avoid stale items
  170. self._queue = asyncio.Queue(maxsize=self._queue_size)
  171. with mark_sentry_task_internal():
  172. self._task = self._loop.create_task(self._target())
  173. self._task_for_pid = os.getpid()
  174. except RuntimeError:
  175. # There is no event loop running
  176. logger.warning("No event loop running, async worker not started")
  177. self._loop = None
  178. self._task = None
  179. self._task_for_pid = None
  180. def full(self) -> bool:
  181. if self._queue is None:
  182. return True
  183. return self._queue.full()
  184. def _ensure_task(self) -> None:
  185. if not self.is_alive:
  186. self.start()
  187. async def _wait_flush(
  188. self, timeout: float, callback: "Optional[Any]" = None
  189. ) -> None:
  190. if not self._loop or not self._loop.is_running() or self._queue is None:
  191. return
  192. initial_timeout = min(0.1, timeout)
  193. # Timeout on the join
  194. try:
  195. await asyncio.wait_for(self._queue.join(), timeout=initial_timeout)
  196. except asyncio.TimeoutError:
  197. pending = self._queue.qsize() + len(self._active_tasks)
  198. logger.debug("%d event(s) pending on flush", pending)
  199. if callback is not None:
  200. callback(pending, timeout)
  201. try:
  202. remaining_timeout = timeout - initial_timeout
  203. await asyncio.wait_for(self._queue.join(), timeout=remaining_timeout)
  204. except asyncio.TimeoutError:
  205. pending = self._queue.qsize() + len(self._active_tasks)
  206. logger.error("flush timed out, dropped %s events", pending)
  207. def flush( # type: ignore[override]
  208. self, timeout: float, callback: "Optional[Any]" = None
  209. ) -> "Optional[asyncio.Task[None]]":
  210. if self.is_alive and timeout > 0.0 and self._loop and self._loop.is_running():
  211. with mark_sentry_task_internal():
  212. return self._loop.create_task(self._wait_flush(timeout, callback))
  213. return None
  214. def submit(self, callback: "Callable[[], Any]") -> bool:
  215. self._ensure_task()
  216. if self._queue is None:
  217. return False
  218. try:
  219. self._queue.put_nowait(callback)
  220. return True
  221. except asyncio.QueueFull:
  222. return False
  223. async def _target(self) -> None:
  224. if self._queue is None:
  225. return
  226. try:
  227. while True:
  228. callback = await self._queue.get()
  229. if callback is _TERMINATOR:
  230. self._queue.task_done()
  231. break
  232. # Firing tasks instead of awaiting them allows for concurrent requests
  233. with mark_sentry_task_internal():
  234. task = asyncio.create_task(self._process_callback(callback))
  235. # Create a strong reference to the task so it can be cancelled on kill
  236. # and does not get garbage collected while running
  237. self._active_tasks.add(task)
  238. # Capture queue ref at dispatch time so done callbacks use the
  239. # correct queue even if kill()/start() replace self._queue.
  240. queue_ref = self._queue
  241. task.add_done_callback(lambda t: self._on_task_complete(t, queue_ref))
  242. # Yield to let the event loop run other tasks
  243. await asyncio.sleep(0)
  244. except asyncio.CancelledError:
  245. pass # Expected during kill()
  246. async def _process_callback(self, callback: "Callable[[], Any]") -> None:
  247. # Callback is an async coroutine, need to await it
  248. await callback()
  249. def _on_task_complete(
  250. self,
  251. task: "asyncio.Task[None]",
  252. queue: "Optional[asyncio.Queue[Any]]" = None,
  253. ) -> None:
  254. try:
  255. task.result()
  256. except asyncio.CancelledError:
  257. pass # Task was cancelled, expected during shutdown
  258. except Exception:
  259. logger.error("Failed processing job", exc_info=True)
  260. finally:
  261. # Mark the task as done and remove it from the active tasks set
  262. # Use the queue reference captured at dispatch time, not self._queue,
  263. # to avoid calling task_done() on a different queue after kill()/start().
  264. if queue is not None:
  265. queue.task_done()
  266. self._active_tasks.discard(task)