"""Logging configuration for the "wandb" logger.

Most log statements in wandb are made in the context of a run and should be
redirected to that run's log file (usually named 'debug.log'). This module
provides context managers that temporarily set the current run ID, and it
configures the global 'wandb' logger so that per-run handlers send each log
statement to the right place.

All functions in this module are thread-safe.

NOTE: The pytest caplog fixture will fail to capture logs from the wandb logger
because they are not propagated to the root logger.
"""

from __future__ import annotations

import contextlib
import contextvars
import logging
import pathlib
from collections.abc import Iterator


class _NotRunSpecific:
    """Sentinel type used by `log_to_all_runs()`."""


_NOT_RUN_SPECIFIC = _NotRunSpecific()

_run_id: contextvars.ContextVar[str | _NotRunSpecific | None] = contextvars.ContextVar(
    "_run_id",
    default=None,
)

_logger = logging.getLogger("wandb")


def configure_wandb_logger() -> None:
    """Configure the global 'wandb' logger.

    The wandb logger is not intended to be customized by users. Instead, it is
    used as a mechanism to redirect log messages into wandb run-specific log
    files.

    This function is idempotent: calling it multiple times has the same effect
    as calling it once.
    """
    # Send all DEBUG and above messages to registered handlers.
    #
    # Per-run handlers can set different levels.
    _logger.setLevel(logging.DEBUG)

    # Do not propagate wandb logs to the root logger, which the user may have
    # configured to point elsewhere. All wandb log messages should go to a run's
    # log file.
    _logger.propagate = False

    # If no handlers are configured for the 'wandb' logger, Python falls back
    # to its "lastResort" handler, which sends messages of level WARNING and
    # above to stderr. Adding a NullHandler prevents that.
    #
    # This situation arises when wandb code runs outside the context of a Run
    # and not as part of the CLI.
    #
    # Most such code uses the `termlog` / `termwarn` / `termerror` methods
    # to communicate with the user. When that code executes while a run is
    # active, its logger messages go to that run's log file.
    if not _logger.handlers:
        _logger.addHandler(logging.NullHandler())


@contextlib.contextmanager
def log_to_run(run_id: str | None) -> Iterator[None]:
    """Direct wandb log messages within the context to the given run.

    Args:
        run_id: The current run ID, or None if actions in the context manager
            are not associated with a specific run. In the latter case, log
            messages will go to all runs.

    Usage:

        with wb_logging.log_to_run(...):
            ...  # Log messages here go to the specified run's log file.
    """
    token = _run_id.set(run_id)
    try:
        yield
    finally:
        _run_id.reset(token)


@contextlib.contextmanager
def log_to_all_runs() -> Iterator[None]:
    """Direct wandb log messages to all runs.

    Unlike `log_to_run(None)`, this indicates an intentional choice.

    This is often convenient to use as a decorator:

        @wb_logging.log_to_all_runs()
        def my_func():
            ...  # Log messages here go to all runs.
    """
    token = _run_id.set(_NOT_RUN_SPECIFIC)
    try:
        yield
    finally:
        _run_id.reset(token)


def add_file_handler(run_id: str, filepath: pathlib.Path) -> logging.Handler:
    """Direct log messages for a run to a file.

    The handler's logging level defaults to INFO.

    Args:
        run_id: The run for which to create a log file.
        filepath: The file to write log messages to.

    Returns:
        The added handler, which can then be configured further or removed
        from the 'wandb' logger directly.
    """
    handler = logging.FileHandler(filepath)
    handler.setLevel(logging.INFO)
    handler.addFilter(_RunIDFilter(run_id))
    handler.setFormatter(
        logging.Formatter(
            "%(asctime)s %(levelname)-7s %(threadName)-10s:%(process)d"
            " [%(filename)s:%(funcName)s():%(lineno)s]%(run_id_tag)s"
            " %(message)s"
        )
    )
    _logger.addHandler(handler)
    return handler


class _RunIDFilter:
    """Filters out messages logged for a different run."""

    def __init__(self, run_id: str) -> None:
        """Create a _RunIDFilter.

        Args:
            run_id: The run whose messages to allow. Messages logged with no
                run ID, or logged for all runs, are also allowed.
        """
        self._run_id = run_id

    def filter(self, record: logging.LogRecord) -> bool:
        """Set the record's `run_id_tag` and return whether it matches the run."""
        run_id = _run_id.get()

        if run_id is None:
            record.run_id_tag = " [no run ID]"
            return True
        elif isinstance(run_id, _NotRunSpecific):
            record.run_id_tag = " [all runs]"
            return True
        else:
            record.run_id_tag = ""
            return run_id == self._run_id
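The central idea of this module is that one logger carries several per-run handlers, and each handler's filter reads a context variable to decide whether a record belongs to its run. The following is a minimal, standard-library-only sketch of that idea. `current_run`, `run_context`, `RunFilter`, and `add_run_handler` are hypothetical names that mirror, but are not, `_run_id`, `log_to_run`, `_RunIDFilter`, and `add_file_handler`, and the sketch writes to in-memory streams instead of files.

```python
import contextlib
import contextvars
import io
import logging

# Hypothetical stand-in for the module's `_run_id` context variable.
current_run = contextvars.ContextVar("current_run", default=None)


@contextlib.contextmanager
def run_context(run_id):
    """Set the current run ID for the duration of the block (like log_to_run)."""
    token = current_run.set(run_id)
    try:
        yield
    finally:
        current_run.reset(token)


class RunFilter(logging.Filter):
    """Let a record through only if it was logged under this handler's run."""

    def __init__(self, run_id):
        super().__init__()
        self.run_id = run_id

    def filter(self, record):
        active = current_run.get()
        # None means "not tied to a run", so every run's handler accepts it.
        return active is None or active == self.run_id


def add_run_handler(logger, run_id):
    """Attach an in-memory handler that keeps only records for `run_id`."""
    stream = io.StringIO()
    handler = logging.StreamHandler(stream)
    handler.addFilter(RunFilter(run_id))
    handler.setFormatter(logging.Formatter("%(message)s"))
    logger.addHandler(handler)
    return stream


logger = logging.getLogger("routing_demo")
logger.setLevel(logging.DEBUG)
logger.propagate = False

out_a = add_run_handler(logger, "run-a")
out_b = add_run_handler(logger, "run-b")

with run_context("run-a"):
    logger.info("only for a")
with run_context("run-b"):
    logger.info("only for b")
logger.info("for everyone")

print(out_a.getvalue().splitlines())  # ['only for a', 'for everyone']
print(out_b.getvalue().splitlines())  # ['only for b', 'for everyone']
```

Every record reaches every handler; the filters, not the logger, do the routing. That is why a single logger object can serve any number of concurrent runs.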
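The thread-safety claim in the module docstring rests on `contextvars`: each thread has its own context, so a run ID set by `log_to_run` in one thread does not change what another thread sees. A small check of that property, assuming a hypothetical `current_run` variable in place of `_run_id`:

```python
import contextvars
import threading

# Hypothetical stand-in for the module's `_run_id` context variable.
current_run = contextvars.ContextVar("current_run", default=None)

# The barrier makes both threads set their value before either reads it back,
# so each read happens after the other thread's write.
barrier = threading.Barrier(2)
seen = {}


def worker(run_id):
    token = current_run.set(run_id)
    try:
        barrier.wait()
        seen[run_id] = current_run.get()
    finally:
        current_run.reset(token)


threads = [threading.Thread(target=worker, args=(r,)) for r in ("run-a", "run-b")]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(sorted(seen.items()))  # [('run-a', 'run-a'), ('run-b', 'run-b')]
print(current_run.get())  # None: the main thread's value was never changed
```

asyncio tasks get the same isolation, because each task runs in a copy of the context captured when the task was created.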
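The comment in `configure_wandb_logger()` about the "lastResort" handler can be observed directly. The hypothetical helper below disables propagation as the module does, optionally adds a `NullHandler`, logs one warning, and returns whatever reached stderr. It uses only `logging` and `contextlib` from the standard library.

```python
import contextlib
import io
import logging


def emit_warning(logger_name, add_null_handler):
    """Log one warning on a non-propagating logger and return what hit stderr."""
    logger = logging.getLogger(logger_name)
    logger.propagate = False  # as configure_wandb_logger() does for 'wandb'
    if add_null_handler and not logger.handlers:
        logger.addHandler(logging.NullHandler())

    captured = io.StringIO()
    # logging.lastResort writes to whatever sys.stderr is at emit time.
    with contextlib.redirect_stderr(captured):
        logger.warning("something happened")
    return captured.getvalue()


print(repr(emit_warning("lastresort_demo", add_null_handler=False)))
# 'something happened\n'  (no handlers at all, so lastResort fires)
print(repr(emit_warning("nullhandler_demo", add_null_handler=True)))
# ''  (the NullHandler counts as a handler, so lastResort stays silent)
```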
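`log_to_all_runs()` works as a decorator because the object returned by a `contextlib.contextmanager` function is also a `contextlib.ContextDecorator`. Each decorated call re-creates and enters a fresh context manager, so the decorated function can be called any number of times. A sketch, with a hypothetical `ALL_RUNS` sentinel standing in for `_NOT_RUN_SPECIFIC`:

```python
import contextlib
import contextvars

# Hypothetical stand-ins for `_run_id` and `_NOT_RUN_SPECIFIC`.
current_run = contextvars.ContextVar("current_run", default=None)
ALL_RUNS = object()


@contextlib.contextmanager
def for_all_runs():
    """Mark everything inside the block as belonging to all runs."""
    token = current_run.set(ALL_RUNS)
    try:
        yield
    finally:
        current_run.reset(token)


@for_all_runs()
def background_task():
    # Inside the decorated call, the sentinel is the current value.
    return current_run.get() is ALL_RUNS


print(background_task())  # True
print(background_task())  # True (a new context manager is entered per call)
print(current_run.get())  # None (the value is reset after each call)
```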
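`_RunIDFilter.filter()` does more than accept or reject a record: it sets `record.run_id_tag`, which the formatter in `add_file_handler()` interpolates through `%(run_id_tag)s`. A handler's filters run before the handler formats the record, so every record that passes already carries the attribute. Like `_RunIDFilter`, the filter below is a plain object with a `filter` method rather than a `logging.Filter` subclass, which `logging` accepts. `TagFilter`, `current_run`, and the shortened format string are assumptions for this sketch.

```python
import contextvars
import io
import logging

# Hypothetical stand-in for the module's `_run_id` context variable.
current_run = contextvars.ContextVar("current_run", default=None)


class TagFilter:
    """Reduced version of the run filter: tags records, then accepts or rejects."""

    def __init__(self, run_id):
        self.run_id = run_id

    def filter(self, record):
        active = current_run.get()
        if active is None:
            record.run_id_tag = " [no run ID]"
            return True
        record.run_id_tag = ""
        return active == self.run_id


stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.addFilter(TagFilter("run-a"))
handler.setFormatter(logging.Formatter("%(levelname)s%(run_id_tag)s %(message)s"))

logger = logging.getLogger("tag_demo")
logger.setLevel(logging.DEBUG)
logger.propagate = False
logger.addHandler(handler)

logger.info("outside any run")

token = current_run.set("run-a")
logger.info("inside run-a")
current_run.reset(token)

token = current_run.set("run-b")
logger.info("inside run-b")  # rejected by run-a's filter, never formatted
current_run.reset(token)

print(stream.getvalue().splitlines())
# ['INFO [no run ID] outside any run', 'INFO inside run-a']
```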
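`add_file_handler()` returns its handler so that the caller can remove it later, for example when a run finishes (an assumption about how callers use it; the module itself does not say when). A sketch of that lifecycle with a plain `FileHandler` in a temporary directory, where the logger name and file name are hypothetical: a DEBUG record is dropped by the handler's INFO level, and the handler is detached and closed before the file is read.

```python
import logging
import pathlib
import tempfile

logger = logging.getLogger("file_handler_demo")
logger.setLevel(logging.DEBUG)  # the logger passes everything to its handlers
logger.propagate = False

with tempfile.TemporaryDirectory() as tmp:
    path = pathlib.Path(tmp) / "debug.log"
    handler = logging.FileHandler(path)
    handler.setLevel(logging.INFO)  # like add_file_handler(): DEBUG is dropped
    logger.addHandler(handler)

    logger.debug("not written")
    logger.info("written")

    # Detach and close so the file is flushed and released.
    logger.removeHandler(handler)
    handler.close()

    lines = path.read_text().splitlines()

print(lines)  # ['written']
```

Closing the handler matters on platforms such as Windows, where an open file cannot be deleted along with its directory.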