"""Logging configuration for the "wandb" logger. Most log statements in wandb are made in the context of a run and should be redirected to that run's log file (usually named 'debug.log'). This module provides a context manager to temporarily set the current run ID and registers a global handler for the 'wandb' logger that sends log statements to the right place. All functions in this module are threadsafe. NOTE: The pytest caplog fixture will fail to capture logs from the wandb logger because they are not propagated to the root logger. """ from __future__ import annotations import contextlib import contextvars import logging import pathlib from collections.abc import Iterator class _NotRunSpecific: """Sentinel for `not_run_specific()`.""" _NOT_RUN_SPECIFIC = _NotRunSpecific() _run_id: contextvars.ContextVar[str | _NotRunSpecific | None] = contextvars.ContextVar( "_run_id", default=None, ) _logger = logging.getLogger("wandb") def configure_wandb_logger() -> None: """Configures the global 'wandb' logger. The wandb logger is not intended to be customized by users. Instead, it is used as a mechanism to redirect log messages into wandb run-specific log files. This function is idempotent: calling it multiple times has the same effect. """ # Send all DEBUG and above messages to registered handlers. # # Per-run handlers can set different levels. _logger.setLevel(logging.DEBUG) # Do not propagate wandb logs to the root logger, which the user may have # configured to point elsewhere. All wandb log messages should go to a run's # log file. _logger.propagate = False # If no handlers are configured for the 'wandb' logger, don't activate the # "lastResort" handler which sends messages to stderr with a level of # WARNING by default. # # This occurs in wandb code that runs outside the context of a Run and # not as part of the CLI. # # Most such code uses the `termlog` / `termwarn` / `termerror` methods # to communicate with the user. When that code executes while a run is # active, its logger messages go to that run's log file. if not _logger.handlers: _logger.addHandler(logging.NullHandler()) @contextlib.contextmanager def log_to_run(run_id: str | None) -> Iterator[None]: """Direct all wandb log messages to the given run. Args: id: The current run ID, or None if actions in the context manager are not associated to a specific run. In the latter case, log messages will go to all runs. Usage: with wb_logging.run_id(...): ... # Log messages here go to the specified run's logger. """ token = _run_id.set(run_id) try: yield finally: _run_id.reset(token) @contextlib.contextmanager def log_to_all_runs() -> Iterator[None]: """Direct wandb log messages to all runs. Unlike `log_to_run(None)`, this indicates an intentional choice. This is often convenient to use as a decorator: @wb_logging.log_to_all_runs() def my_func(): ... # Log messages here go to the specified run's logger. """ token = _run_id.set(_NOT_RUN_SPECIFIC) try: yield finally: _run_id.reset(token) def add_file_handler(run_id: str, filepath: pathlib.Path) -> logging.Handler: """Direct log messages for a run to a file. Args: run_id: The run for which to create a log file. filepath: The file to write log messages to. Returns: The added handler which can then be configured further or removed from the 'wandb' logger directly. The default logging level is INFO. """ handler = logging.FileHandler(filepath) handler.setLevel(logging.INFO) handler.addFilter(_RunIDFilter(run_id)) handler.setFormatter( logging.Formatter( "%(asctime)s %(levelname)-7s %(threadName)-10s:%(process)d" " [%(filename)s:%(funcName)s():%(lineno)s]%(run_id_tag)s" " %(message)s" ) ) _logger.addHandler(handler) return handler class _RunIDFilter: """Filters out messages logged for a different run.""" def __init__(self, run_id: str) -> None: """Create a _RunIDFilter. Args: run_id: Allows messages when the run ID is this or None. """ self._run_id = run_id def filter(self, record: logging.LogRecord) -> bool: """Modify a log record and return whether it matches the run.""" run_id = _run_id.get() if run_id is None: record.run_id_tag = " [no run ID]" return True elif isinstance(run_id, _NotRunSpecific): record.run_id_tag = " [all runs]" return True else: record.run_id_tag = "" return run_id == self._run_id