|
- """Module for intercepting stdout/stderr.
- This patches the `write()` method of `stdout` and `stderr` on import.
- Once patched, it is not possible to unpatch or repatch, though individual
- callbacks can be removed.
- We assume that all other writing methods on the object delegate to `write()`,
- like `writelines()`. This is not guaranteed to be true, but it is true for
- common implementations. In particular, CPython's implementation of IOBase's
- `writelines()` delegates to `write()`.
- It is important to note that this technique interacts poorly with other
- code that performs similar patching if it also allows unpatching as this
- discards our modification. This is why we patch on import and do not support
- unpatching:
-     with contextlib.redirect_stderr(...):
-         from ... import console_capture
-         # Here, everything works fine.
-     # Here, callbacks are never called again.
- In particular, it does not work with some combinations of pytest's
- `capfd` / `capsys` fixtures and pytest's `--capture` option.
- """
- from __future__ import annotations
- import contextlib
- import contextvars
- import logging
- import sys
- import threading
- from collections.abc import Iterator
- from typing import IO, AnyStr, Callable, Protocol
- from . import wb_logging
- _logger = logging.getLogger(__name__)
- class CannotCaptureConsoleError(Exception):
-     """The module failed to patch stdout or stderr."""
- class _WriteCallback(Protocol):
-     """A callback that receives intercepted bytes or string data.
-     This may be called from any thread, but is only called from one thread
-     at a time.
-     Note on errors: Any error raised during the callback will clear all
-     callbacks. This means that if a user presses Ctrl-C at an unlucky time
-     during a run, we will stop uploading console output---but it's not
-     likely to be a problem unless something catches the KeyboardInterrupt.
-     Regular Exceptions are caught and logged instead of bubbling up to the
-     user's print() statements; other exceptions like KeyboardInterrupt are
-     re-raised.
-     Callbacks should handle all exceptions---a callback that raises any
-     Exception is considered buggy.
-     """
-     def __call__(
-         self,
-         data: bytes | str,
-         written: int,
-         /,
-     ) -> None:
-         """Intercept data passed to `write()`.
-         See the protocol docstring for information about exceptions.
-         Args:
-             data: The object passed to stderr's or stdout's `write()`.
-             written: The number of bytes or characters written.
-                 This is the return value of `write()`.
-         """
- # See _enter_callbacks() for why this is an RLock.
- _module_rlock = threading.RLock()
- # See _enter_callbacks().
- _is_writing = False
- _is_caused_by_callback = contextvars.ContextVar(
-     "_is_caused_by_callback",
-     default=False,
- )
- _patch_exception: CannotCaptureConsoleError | None = None
- def _maybe_raise_patch_exception() -> None:
-     """Raise _patch_exception if it's set."""
-     if _patch_exception:
-         # Each `raise` call modifies the exception's __traceback__, so we must
-         # reset the traceback to reuse the exception.
-         #
-         # See https://bugs.python.org/issue45924 for an example of the problem.
-         raise _patch_exception.with_traceback(None)
- _next_callback_id: int = 1
- _stdout_callbacks: dict[int, _WriteCallback] = {}
- _stderr_callbacks: dict[int, _WriteCallback] = {}
- def capture_stdout(callback: _WriteCallback) -> Callable[[], None]:
-     """Install a callback that runs after every write to sys.stdout.
-     Args:
-         callback: A callback to invoke after running `sys.stdout.write`.
-     Returns:
-         A function to uninstall the callback.
-     Raises:
-         CannotCaptureConsoleError: If patching failed on import.
-     """
-     _maybe_raise_patch_exception()
-     with _module_rlock:
-         return _insert_disposably(
-             _stdout_callbacks,
-             callback,
-         )
- def capture_stderr(callback: _WriteCallback) -> Callable[[], None]:
-     """Install a callback that runs after every write to sys.stderr.
-     Args:
-         callback: A callback to invoke after running `sys.stderr.write`.
-     Returns:
-         A function to uninstall the callback.
-     Raises:
-         CannotCaptureConsoleError: If patching failed on import.
-     """
-     _maybe_raise_patch_exception()
-     with _module_rlock:
-         return _insert_disposably(
-             _stderr_callbacks,
-             callback,
-         )
- def _insert_disposably(
-     callback_dict: dict[int, _WriteCallback],
-     callback: _WriteCallback,
- ) -> Callable[[], None]:
-     global _next_callback_id
-     id = _next_callback_id
-     _next_callback_id += 1
-     disposed = False
-     def dispose() -> None:
-         nonlocal disposed
-         with _module_rlock:
-             if disposed:
-                 return
-             callback_dict.pop(id, None)
-             disposed = True
-     callback_dict[id] = callback
-     return dispose
- def _patch(
-     stdout_or_stderr: IO[AnyStr],
-     callbacks: dict[int, _WriteCallback],
- ) -> None:
-     orig_write: Callable[[AnyStr], int]
-     def write_with_callbacks(s: AnyStr, /) -> int:
-         n = orig_write(s)
-         with contextlib.ExitStack() as stack:
-             stack.enter_context(_reset_on_exception())
-             stack.enter_context(wb_logging.log_to_all_runs())
-             callbacks_list = stack.enter_context(_enter_callbacks(callbacks))
-             for cb in callbacks_list:
-                 cb(s, n)
-         return n
-     orig_write = stdout_or_stderr.write
-     # mypy==1.14.1 fails to type-check this:
-     #   Incompatible types in assignment (expression has type
-     #   "Callable[[bytes], int]", variable has type overloaded function)
-     stdout_or_stderr.write = write_with_callbacks  # type: ignore
- @contextlib.contextmanager
- def _enter_callbacks(
-     callbacks: dict[int, _WriteCallback],
- ) -> Iterator[list[_WriteCallback]]:
-     """Returns a list of callbacks to invoke.
-     This prevents deadlocks and some infinite loops by returning an empty list
-     when:
-     * A callback prints
-     * A callback blocks on a thread that's printing
-     * A callback schedules an async task that prints
-     It is impossible to prevent all infinite loops: a callback could put
-     a message into a queue, causing an unrelated thread to print later,
-     invoking the same callback and repeating forever.
-     """
-     global _is_writing
-     # NOTE 1: _is_writing
-     #
-     # The global _is_writing variable is necessary despite the contextvar
-     # because it's possible to create a thread without copying the context.
-     # This is the default behavior for threading.Thread() before Python 3.14.
-     #
-     # A side effect of it is that when multiple threads print simultaneously,
-     # some messages will not be captured.
-     #
-     # NOTE 2: _module_rlock
-     #
-     # We use a reentrant lock primarily because GC can trigger on the current
-     # thread at any time. Since GC can run arbitrary code via `__del__`, it may
-     # print and hit this lock while it's already held.
-     #
-     # Technically, even the simple methods called while holding the lock
-     # could be patched to print, but a reentrant lock just turns this
-     # deadlock into an infinite loop.
-     #
-     # Text printed during GC may or may not be captured.
-     # Assuming that the GC thread cannot itself be stolen, GC does not thwart
-     # the infinite loop protections:
-     #
-     # * If _is_writing == True, GC will not print or touch _is_writing
-     # * If _is_writing == False, GC will set it to True and reset it to
-     #   False before giving back the thread
-     #
-     # Specifically, there is no situation where _is_writing "spontaneously"
-     # changes from True to False.
-     with _module_rlock:
-         if _is_writing or _is_caused_by_callback.get():
-             callbacks_list = None
-         else:
-             callbacks_list = list(callbacks.values())
-             _is_writing = True
-             _is_caused_by_callback.set(True)
-     if callbacks_list is None:
-         yield []
-         return
-     try:
-         yield callbacks_list
-     finally:
-         with _module_rlock:
-             _is_writing = False
-             _is_caused_by_callback.set(False)
- @contextlib.contextmanager
- def _reset_on_exception() -> Iterator[None]:
-     """Clear all callbacks on any exception, suppressing it.
-     This prevents infinite loops:
-     * If we re-raise, an exception handler is likely to print
-       the exception to the console and trigger callbacks again
-     * If we log, we can't guarantee that this doesn't print
-       to console.
-     This is especially important for KeyboardInterrupt.
-     """
-     try:
-         yield
-     except BaseException as e:
-         with _module_rlock:
-             _stderr_callbacks.clear()
-             _stdout_callbacks.clear()
-         if isinstance(e, Exception):
-             # We suppress Exceptions so that bugs in W&B code don't
-             # cause the user's print() statements to raise errors.
-             _logger.exception("Error in console callback, clearing all!")
-         else:
-             # Re-raise errors like KeyboardInterrupt.
-             raise
- try:
-     _patch(sys.stdout, _stdout_callbacks)
-     _patch(sys.stderr, _stderr_callbacks)
- except Exception as _patch_exception_cause:
-     _patch_exception = CannotCaptureConsoleError()
-     _patch_exception.__cause__ = _patch_exception_cause
|
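The core technique in this module, wrapping a stream's `write()` and handing back a disposer for each registered callback, can be sketched in isolation. This is a simplified illustration and not the module's implementation: `patch_write`, `register`, and `DemoStream` are hypothetical names, the sketch patches an in-memory stream instead of `sys.stdout`, and it leaves out the reentrancy guards and exception handling that the real code needs.

```python
import io
import threading
from typing import Callable

# Hypothetical names for illustration; none of these exist in the module.
WriteCallback = Callable[[str, int], None]
_lock = threading.RLock()


def patch_write(stream: io.TextIOBase) -> Callable[[WriteCallback], Callable[[], None]]:
    """Wrap stream.write() and return a function that registers callbacks."""
    callbacks: dict[int, WriteCallback] = {}
    next_id = 0
    orig_write = stream.write

    def write_with_callbacks(s: str, /) -> int:
        # Write first, then report the data and the return value of write().
        n = orig_write(s)
        with _lock:
            snapshot = list(callbacks.values())
        for cb in snapshot:
            cb(s, n)
        return n

    # Shadow the method on the instance; the type itself is left untouched.
    stream.write = write_with_callbacks  # type: ignore[method-assign]

    def register(cb: WriteCallback) -> Callable[[], None]:
        nonlocal next_id
        with _lock:
            cb_id = next_id
            next_id += 1
            callbacks[cb_id] = cb

        def dispose() -> None:
            # Safe to call more than once: pop() ignores a missing key.
            with _lock:
                callbacks.pop(cb_id, None)

        return dispose

    return register


class DemoStream(io.StringIO):
    """A Python subclass, so instance attributes can always be assigned."""


stream = DemoStream()
register = patch_write(stream)

seen: list[tuple[str, int]] = []
dispose = register(lambda data, written: seen.append((data, written)))

stream.write("hello\n")    # captured by the callback
dispose()
dispose()                  # a second call is a no-op
stream.write("ignored\n")  # written to the stream, but no longer captured
```

Because only the instance attribute is replaced, any later code that swaps the stream object itself (as `contextlib.redirect_stderr` does with `sys.stderr`) bypasses the wrapper, which is the interaction the module docstring warns about.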