logger.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # mypy: allow-untyped-defs
  2. import functools
  3. import logging
  4. import time
  5. from collections.abc import Callable
  6. from typing import Any, TypeVar
  7. from typing_extensions import ParamSpec
  8. from uuid import uuid4
  9. import torch.distributed.c10d_logger as c10d_logger
  10. from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME
  11. logger = logging.getLogger()
  12. __all__: list[str] = []
  13. # pyrefly: ignore [unknown-name]
  14. global _dcp_logger
  15. _dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME)
  16. _T = TypeVar("_T")
  17. _P = ParamSpec("_P")
  18. def _msg_dict_from_dcp_method_args(*args, **kwargs) -> dict[str, Any]:
  19. """
  20. Extracts log data from dcp method args
  21. """
  22. msg_dict = {}
  23. # checkpoint ID can be passed in through the serializer or through the checkpoint id directly
  24. storage_writer = kwargs.get("storage_writer")
  25. storage_reader = kwargs.get("storage_reader")
  26. planner = kwargs.get("planner")
  27. checkpoint_id = kwargs.get("checkpoint_id")
  28. if not checkpoint_id and (serializer := storage_writer or storage_reader):
  29. checkpoint_id = getattr(serializer, "checkpoint_id", None)
  30. msg_dict["checkpoint_id"] = (
  31. # pyrefly: ignore [unsupported-operation]
  32. str(checkpoint_id) if checkpoint_id is not None else checkpoint_id
  33. )
  34. # Uniquely identify a _dcp_method_logger wrapped function call.
  35. msg_dict["uuid"] = str(uuid4().int)
  36. if storage_writer:
  37. msg_dict["storage_writer"] = storage_writer.__class__.__name__
  38. if storage_reader:
  39. msg_dict["storage_reader"] = storage_reader.__class__.__name__
  40. if planner:
  41. msg_dict["planner"] = planner.__class__.__name__
  42. return msg_dict
  43. def _get_msg_dict(func_name, *args, **kwargs) -> dict[str, Any]:
  44. msg_dict = _msg_dict_from_dcp_method_args(*args, **kwargs)
  45. msg_dict.update(c10d_logger._get_msg_dict(func_name, *args, **kwargs))
  46. return msg_dict
  47. def _dcp_method_logger(
  48. log_exceptions: bool = False, **wrapper_kwargs: Any
  49. ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: # pyre-ignore
  50. """This method decorator logs the start, end, and exception of wrapped events."""
  51. def decorator(func: Callable[_P, _T]):
  52. @functools.wraps(func)
  53. def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
  54. msg_dict = _get_msg_dict(
  55. func.__name__, *args, **{**wrapper_kwargs, **kwargs}
  56. )
  57. # log start event
  58. msg_dict["event"] = "start"
  59. t0 = time.time_ns()
  60. msg_dict["time"] = t0
  61. msg_dict["log_exceptions"] = log_exceptions
  62. _dcp_logger.debug(msg_dict)
  63. # exceptions
  64. try:
  65. result = func(*args, **kwargs)
  66. except BaseException as error:
  67. if log_exceptions:
  68. msg_dict["event"] = "exception"
  69. msg_dict["error"] = f"{error}"
  70. msg_dict["time"] = time.time_ns()
  71. _dcp_logger.error(msg_dict)
  72. raise
  73. # end event
  74. msg_dict["event"] = "end"
  75. t1 = time.time_ns()
  76. msg_dict["time"] = time.time_ns()
  77. msg_dict["times_spent"] = t1 - t0
  78. _dcp_logger.debug(msg_dict)
  79. return result
  80. return wrapper
  81. return decorator
  82. def _init_logger(rank: int):
  83. logger.setLevel(logging.INFO)
  84. ch = logging.StreamHandler()
  85. ch.setLevel(logging.INFO)
  86. formatter = logging.Formatter(
  87. f"[{rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s"
  88. )
  89. ch.setFormatter(formatter)
  90. logger.addHandler(ch)