console_capture.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
"""Module for intercepting stdout/stderr.

This patches the `write()` method of `stdout` and `stderr` on import.
Once patched, it is not possible to unpatch or repatch, though individual
callbacks can be removed.

We assume that all other writing methods on the object delegate to `write()`,
like `writelines()`. This is not guaranteed to be true, but it is true for
common implementations. In particular, CPython's implementation of IOBase's
`writelines()` delegates to `write()`.

It is important to note that this technique interacts poorly with other
code that performs similar patching if it also allows unpatching, as this
discards our modification. This is why we patch on import and do not support
unpatching:

    with contextlib.redirect_stderr(...):
        from ... import console_capture
        # Here, everything works fine.
    # Here, callbacks are never called again.

In particular, it does not work with some combinations of pytest's
`capfd` / `capsys` fixtures and pytest's `--capture` option.
"""
  20. from __future__ import annotations
  21. import contextlib
  22. import contextvars
  23. import logging
  24. import sys
  25. import threading
  26. from collections.abc import Iterator
  27. from typing import IO, AnyStr, Callable, Protocol
  28. from . import wb_logging
# Module-level logger; _reset_on_exception() reports callback errors through it.
_logger = logging.getLogger(__name__)
  30. class CannotCaptureConsoleError(Exception):
  31. """The module failed to patch stdout or stderr."""
  32. class _WriteCallback(Protocol):
  33. """A callback that receives intercepted bytes or string data.
  34. This may be called from any thread, but is only called from one thread
  35. at a time.
  36. Note on errors: Any error raised during the callback will clear all
  37. callbacks. This means that if a user presses Ctrl-C at an unlucky time
  38. during a run, we will stop uploading console output---but it's not
  39. likely to be a problem unless something catches the KeyboardInterrupt.
  40. Regular Exceptions are caught and logged instead of bubbling up to the
  41. user's print() statements; other exceptions like KeyboardInterrupt are
  42. re-raised.
  43. Callbacks should handle all exceptions---a callback that raises any
  44. Exception is considered buggy.
  45. """
  46. def __call__(
  47. self,
  48. data: bytes | str,
  49. written: int,
  50. /,
  51. ) -> None:
  52. """Intercept data passed to `write()`.
  53. See the protocol docstring for information about exceptions.
  54. Args:
  55. data: The object passed to stderr's or stdout's `write()`.
  56. written: The number of bytes or characters written.
  57. This is the return value of `write()`.
  58. """
  59. # See _enter_callbacks() for why this is an RLock.
  60. _module_rlock = threading.RLock()
  61. # See _enter_callbacks().
  62. _is_writing = False
  63. _is_caused_by_callback = contextvars.ContextVar(
  64. "_is_caused_by_callback",
  65. default=False,
  66. )
  67. _patch_exception: CannotCaptureConsoleError | None = None
  68. def _maybe_raise_patch_exception() -> None:
  69. """Raise _patch_exception if it's set."""
  70. if _patch_exception:
  71. # Each `raise` call modifies the exception's __traceback__, so we must
  72. # reset the traceback to reuse the exception.
  73. #
  74. # See https://bugs.python.org/issue45924 for an example of the problem.
  75. raise _patch_exception.with_traceback(None)
  76. _next_callback_id: int = 1
  77. _stdout_callbacks: dict[int, _WriteCallback] = {}
  78. _stderr_callbacks: dict[int, _WriteCallback] = {}
  79. def capture_stdout(callback: _WriteCallback) -> Callable[[], None]:
  80. """Install a callback that runs after every write to sys.stdout.
  81. Args:
  82. callback: A callback to invoke after running `sys.stdout.write`.
  83. Returns:
  84. A function to uninstall the callback.
  85. Raises:
  86. CannotCaptureConsoleError: If patching failed on import.
  87. """
  88. _maybe_raise_patch_exception()
  89. with _module_rlock:
  90. return _insert_disposably(
  91. _stdout_callbacks,
  92. callback,
  93. )
  94. def capture_stderr(callback: _WriteCallback) -> Callable[[], None]:
  95. """Install a callback that runs after every write to sys.sdterr.
  96. Args:
  97. callback: A callback to invoke after running `sys.stderr.write`.
  98. Returns:
  99. A function to uninstall the callback.
  100. Raises:
  101. CannotCaptureConsoleError: If patching failed on import.
  102. """
  103. _maybe_raise_patch_exception()
  104. with _module_rlock:
  105. return _insert_disposably(
  106. _stderr_callbacks,
  107. callback,
  108. )
  109. def _insert_disposably(
  110. callback_dict: dict[int, _WriteCallback],
  111. callback: _WriteCallback,
  112. ) -> Callable[[], None]:
  113. global _next_callback_id
  114. id = _next_callback_id
  115. _next_callback_id += 1
  116. disposed = False
  117. def dispose() -> None:
  118. nonlocal disposed
  119. with _module_rlock:
  120. if disposed:
  121. return
  122. callback_dict.pop(id, None)
  123. disposed = True
  124. callback_dict[id] = callback
  125. return dispose
  126. def _patch(
  127. stdout_or_stderr: IO[AnyStr],
  128. callbacks: dict[int, _WriteCallback],
  129. ) -> None:
  130. orig_write: Callable[[AnyStr], int]
  131. def write_with_callbacks(s: AnyStr, /) -> int:
  132. n = orig_write(s)
  133. with contextlib.ExitStack() as stack:
  134. stack.enter_context(_reset_on_exception())
  135. stack.enter_context(wb_logging.log_to_all_runs())
  136. callbacks_list = stack.enter_context(_enter_callbacks(callbacks))
  137. for cb in callbacks_list:
  138. cb(s, n)
  139. return n
  140. orig_write = stdout_or_stderr.write
  141. # mypy==1.14.1 fails to type-check this:
  142. # Incompatible types in assignment (expression has type
  143. # "Callable[[bytes], int]", variable has type overloaded function)
  144. stdout_or_stderr.write = write_with_callbacks # type: ignore
  145. @contextlib.contextmanager
  146. def _enter_callbacks(
  147. callbacks: dict[int, _WriteCallback],
  148. ) -> Iterator[list[_WriteCallback]]:
  149. """Returns a list of callbacks to invoke.
  150. This prevents deadlocks and some infinite loops by returning an empty list
  151. when:
  152. * A callback prints
  153. * A callback blocks on a thread that's printing
  154. * A callback schedules an async task that prints
  155. It is impossible to prevent all infinite loops: a callback could put
  156. a message into a queue, causing an unrelated thread to print later,
  157. invoking the same callback and repeating forever.
  158. """
  159. global _is_writing
  160. # NOTE 1: _is_writing
  161. #
  162. # The global _is_writing variable is necessary despite the contextvar
  163. # because it's possible to create a thread without copying the context.
  164. # This is the default behavior for threading.Thread() before Python 3.14.
  165. #
  166. # A side effect of it is that when multiple threads print simultaneously,
  167. # some messages will not be captured.
  168. #
  169. # NOTE 2: _module_rlock
  170. #
  171. # We use a reentrant lock primarily because GC can trigger on the current
  172. # thread at any time. Since GC can run arbitrary code via `__del__`, it may
  173. # print and hit this lock while it's already held.
  174. #
  175. # Technically, even the simple methods called while holding the lock
  176. # could be patched to print, but a reentrant lock just turns this
  177. # deadlock into an infinite loop.
  178. #
  179. # Text printed during GC may or may not be captured.
  180. # Assuming that the GC thread cannot itself be stolen, GC does not thwart
  181. # the infinite loop protections:
  182. #
  183. # * If _is_writing == True, GC will not print or touch _is_writing
  184. # * If _is_writing == False, GC will set it to True and reset it to
  185. # False before giving back the thread
  186. #
  187. # Specifically, there is no situation where _is_writing "spontaneously"
  188. # changes from True to False.
  189. with _module_rlock:
  190. if _is_writing or _is_caused_by_callback.get():
  191. callbacks_list = None
  192. else:
  193. callbacks_list = list(callbacks.values())
  194. _is_writing = True
  195. _is_caused_by_callback.set(True)
  196. if callbacks_list is None:
  197. yield []
  198. return
  199. try:
  200. yield callbacks_list
  201. finally:
  202. with _module_rlock:
  203. _is_writing = False
  204. _is_caused_by_callback.set(False)
  205. @contextlib.contextmanager
  206. def _reset_on_exception() -> Iterator[None]:
  207. """Clear all callbacks on any exception, suppressing it.
  208. This prevents infinite loops:
  209. * If we re-raise, an exception handler is likely to print
  210. the exception to the console and trigger callbacks again
  211. * If we log, we can't guarantee that this doesn't print
  212. to console.
  213. This is especially important for KeyboardInterrupt.
  214. """
  215. try:
  216. yield
  217. except BaseException as e:
  218. with _module_rlock:
  219. _stderr_callbacks.clear()
  220. _stdout_callbacks.clear()
  221. if isinstance(e, Exception):
  222. # We suppress Exceptions so that bugs in W&B code don't
  223. # cause the user's print() statements to raise errors.
  224. _logger.exception("Error in console callback, clearing all!")
  225. else:
  226. # Re-raise errors like KeyboardInterrupt.
  227. raise
# Patch both streams once, at import time; the module docstring explains why
# unpatching is not supported. If either patch fails, the failure is recorded
# so that capture_stdout() / capture_stderr() raise CannotCaptureConsoleError
# (with the original error as its __cause__).
try:
    _patch(sys.stdout, _stdout_callbacks)
    _patch(sys.stderr, _stderr_callbacks)
except Exception as _patch_exception_cause:
    # NOTE(review): if patching stdout succeeded but stderr failed, stdout
    # stays patched; callbacks just cannot be registered for either stream.
    _patch_exception = CannotCaptureConsoleError()
    _patch_exception.__cause__ = _patch_exception_cause