c10d_logger.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. #!/usr/bin/env python3
  2. # mypy: allow-untyped-defs
  3. # Copyright (c) Facebook, Inc. and its affiliates.
  4. # All rights reserved.
  5. #
  6. # This source code is licensed under the BSD-style license found in the
  7. # LICENSE file in the root directory of this source tree.
  8. import functools
  9. import logging
  10. from collections.abc import Callable
  11. from typing import Any, TypeVar
  12. from typing_extensions import ParamSpec
  13. import torch
  14. import torch.distributed as dist
  15. from torch.distributed.logging_handlers import _log_handlers
  16. from torch.monitor import _WaitCounter
  17. __all__: list[str] = []
  18. _DEFAULT_DESTINATION = "default"
  19. def _get_or_create_logger(destination: str = _DEFAULT_DESTINATION) -> logging.Logger:
  20. logging_handler, log_handler_name = _get_logging_handler(destination)
  21. logger = logging.getLogger(f"c10d-{log_handler_name}")
  22. logger.setLevel(logging.DEBUG)
  23. formatter = logging.Formatter(
  24. "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
  25. )
  26. logging_handler.setFormatter(formatter)
  27. logger.propagate = False
  28. logger.addHandler(logging_handler)
  29. return logger
  30. def _get_logging_handler(
  31. destination: str = _DEFAULT_DESTINATION,
  32. ) -> tuple[logging.Handler, str]:
  33. log_handler = _log_handlers[destination]
  34. log_handler_name = f"{type(log_handler).__name__}-{destination}"
  35. return (log_handler, log_handler_name)
  36. # pyrefly: ignore [unknown-name]
  37. global _c10d_logger
  38. _c10d_logger = _get_or_create_logger()
  39. def _get_msg_dict(func_name, *args, **kwargs) -> dict[str, Any]:
  40. if dist.is_initialized():
  41. group = kwargs.get("group") or kwargs.get("process_group")
  42. msg_dict = {
  43. "func_name": f"{func_name}",
  44. "pg_name": f"{dist._get_process_group_name(kwargs.get('pg'))}", # type: ignore[arg-type]
  45. "backend": f"{dist.get_backend(group)}",
  46. "world_size": f"{dist.get_world_size()}",
  47. "group_size": f"{dist.get_world_size(group)}",
  48. "global_rank": f"{dist.get_rank()}",
  49. "local_rank": f"{dist.get_rank(group)}",
  50. }
  51. if msg_dict["backend"] == "nccl":
  52. nccl_version = torch.cuda.nccl.version()
  53. msg_dict["nccl_version"] = ".".join(str(v) for v in nccl_version)
  54. else:
  55. msg_dict = {
  56. "func_name": f"{func_name}",
  57. }
  58. return msg_dict
  59. _T = TypeVar("_T")
  60. _P = ParamSpec("_P")
  61. def _exception_logger(func: Callable[_P, _T]) -> Callable[_P, _T]:
  62. @functools.wraps(func)
  63. def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
  64. try:
  65. return func(*args, **kwargs)
  66. except Exception as error:
  67. msg_dict = _get_msg_dict(func.__name__, *args, **kwargs)
  68. msg_dict["error"] = f"{error}"
  69. _c10d_logger.debug(msg_dict)
  70. raise
  71. return wrapper
  72. def _time_logger(func: Callable[_P, _T]) -> Callable[_P, _T]:
  73. @functools.wraps(func)
  74. def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
  75. with _WaitCounter(f"pytorch.wait_counter.c10d.{func.__name__}").guard():
  76. func_return = func(*args, **kwargs)
  77. return func_return
  78. return wrapper