logging.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
"""Logging utilities for Dynamo and Inductor.

This module provides specialized logging functionality, including:

- Step-based logging that prepends a step number to each log message
- Progress bar management for compilation phases
- Centralized logger management for Dynamo and Inductor components

Together these help track the progress of compilation phases and provide
structured logging output for debugging and monitoring.
"""
  9. import itertools
  10. import logging
  11. from collections.abc import Callable
  12. from typing import Any
  13. from torch.hub import _Faketqdm, tqdm
  14. # Disable progress bar by default, not in dynamo config because otherwise get a circular import
  15. disable_progress = True
  16. # Return all loggers that torchdynamo/torchinductor is responsible for
  17. def get_loggers() -> list[logging.Logger]:
  18. return [
  19. logging.getLogger("torch.fx.experimental.symbolic_shapes"),
  20. logging.getLogger("torch._dynamo"),
  21. logging.getLogger("torch._inductor"),
  22. ]
  23. # Creates a logging function that logs a message with a step # prepended.
  24. # get_step_logger should be lazily called (i.e. at runtime, not at module-load time)
  25. # so that step numbers are initialized properly. e.g.:
  26. # @functools.cache
  27. # def _step_logger():
  28. # return get_step_logger(logging.getLogger(...))
  29. # def fn():
  30. # _step_logger()(logging.INFO, "msg")
  31. _step_counter = itertools.count(1)
  32. # Update num_steps if more phases are added: Dynamo, AOT, Backend
  33. # This is very inductor centric
  34. # _inductor.utils.has_triton() gives a circular import error here
  35. if not disable_progress:
  36. try:
  37. import triton # noqa: F401
  38. num_steps = 3
  39. except ImportError:
  40. num_steps = 2
  41. pbar = tqdm(total=num_steps, desc="torch.compile()", delay=0)
  42. def get_step_logger(logger: logging.Logger) -> Callable[..., None]:
  43. if not disable_progress:
  44. pbar.update(1)
  45. if not isinstance(pbar, _Faketqdm):
  46. pbar.set_postfix_str(f"{logger.name}")
  47. step = next(_step_counter)
  48. def log(level: int, msg: str, **kwargs: Any) -> None:
  49. if "stacklevel" not in kwargs:
  50. kwargs["stacklevel"] = 2
  51. logger.log(level, "Step %s: %s", step, msg, **kwargs)
  52. return log