__init__.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. # mypy: allow-untyped-defs
  2. import logging
  3. import sys
  4. import traceback
  5. import typing
  6. from datetime import timedelta
  7. import torch
  8. RankType = int | torch.SymInt
  9. log = logging.getLogger(__name__)
  10. def is_available() -> bool:
  11. """
  12. Return ``True`` if the distributed package is available.
  13. Otherwise,
  14. ``torch.distributed`` does not expose any other APIs. Currently,
  15. ``torch.distributed`` is available on Linux, MacOS and Windows. Set
  16. ``USE_DISTRIBUTED=1`` to enable it when building PyTorch from source.
  17. Currently, the default value is ``USE_DISTRIBUTED=1`` for Linux and Windows,
  18. ``USE_DISTRIBUTED=0`` for MacOS.
  19. """
  20. return hasattr(torch._C, "_c10d_init")
  21. if is_available() and not torch._C._c10d_init():
  22. raise RuntimeError("Failed to initialize torch.distributed")
  23. # Custom Runtime Errors thrown from the distributed package
  24. DistError = torch._C._DistError
  25. DistBackendError = torch._C._DistBackendError
  26. DistNetworkError = torch._C._DistNetworkError
  27. DistStoreError = torch._C._DistStoreError
  28. QueueEmptyError = torch._C._DistQueueEmptyError
  29. if is_available():
  30. from torch._C._distributed_c10d import (
  31. _broadcast_coalesced,
  32. _compute_bucket_assignment_by_size,
  33. _ControlCollectives,
  34. _DEFAULT_FIRST_BUCKET_BYTES,
  35. _make_nccl_premul_sum,
  36. _register_builtin_comm_hook,
  37. _register_comm_hook,
  38. _StoreCollectives,
  39. _test_python_store,
  40. _verify_params_across_processes,
  41. Backend as _Backend,
  42. BuiltinCommHookType,
  43. DebugLevel,
  44. FileStore,
  45. get_debug_level,
  46. GradBucket,
  47. Logger,
  48. PrefixStore,
  49. ProcessGroup as ProcessGroup,
  50. Reducer,
  51. set_debug_level,
  52. set_debug_level_from_env,
  53. Store,
  54. TCPStore,
  55. Work as _Work,
  56. )
  57. def _make_distributed_pdb():
  58. """
  59. Supports using PDB from inside a multiprocessing child process.
  60. Usage:
  61. _make_distributed_pdb().set_trace()
  62. """
  63. # Lazy import pdb only if we set breakpoints.
  64. import pdb
  65. class _DistributedPdb(pdb.Pdb):
  66. def interaction(self, *args, **kwargs):
  67. _stdin = sys.stdin
  68. try:
  69. with open("/dev/stdin") as sys.stdin:
  70. pdb.Pdb.interaction(self, *args, **kwargs)
  71. finally:
  72. sys.stdin = _stdin
  73. return _DistributedPdb()
  74. _breakpoint_cache: dict[int, typing.Any] = {}
  75. def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600):
  76. """
  77. Set a breakpoint, but only on a single rank. All other ranks will wait for you to be
  78. done with the breakpoint before continuing.
  79. Args:
  80. rank (int): Which rank to break on. Default: ``0``
  81. skip (int): Skip the first ``skip`` calls to this breakpoint. Default: ``0``.
  82. """
  83. if skip > 0:
  84. key = hash(str(traceback.format_exc()))
  85. counter = _breakpoint_cache.get(key, 0) + 1
  86. _breakpoint_cache[key] = counter
  87. if counter <= skip:
  88. log.warning("Skip the breakpoint, counter=%d", counter)
  89. return
  90. # avoid having the default timeout (if short) interrupt your debug session
  91. if timeout_s is not None:
  92. for group in torch.distributed.distributed_c10d._pg_map:
  93. torch.distributed.distributed_c10d._set_pg_timeout(
  94. timedelta(seconds=timeout_s), group
  95. )
  96. if get_rank() == rank:
  97. pdb = _make_distributed_pdb()
  98. pdb.message(
  99. "\n!!! ATTENTION !!!\n\n"
  100. f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n"
  101. )
  102. pdb.set_trace()
  103. # If Meta/Python keys are in the TLS, we want to make sure that we ignore them
  104. # and hit the (default) CPU/CUDA implementation of barrier.
  105. meta_in_tls = torch._C._meta_in_tls_dispatch_include()
  106. guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
  107. torch._C._set_meta_in_tls_dispatch_include(False)
  108. try:
  109. barrier()
  110. finally:
  111. torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)
  112. del guard
  113. if sys.platform != "win32":
  114. from torch._C._distributed_c10d import HashStore
  115. from .device_mesh import DeviceMesh, init_device_mesh
  116. # Variables prefixed with underscore are not auto imported
  117. # See the comment in `distributed_c10d.py` above `_backend` on why we expose
  118. # this.
  119. from .distributed_c10d import * # noqa: F403
  120. from .distributed_c10d import (
  121. _all_gather_base,
  122. _coalescing_manager,
  123. _CoalescingManager,
  124. _create_process_group_wrapper,
  125. _get_process_group_name,
  126. _rank_not_in_group,
  127. _reduce_scatter_base,
  128. _time_estimator,
  129. get_node_local_rank,
  130. )
  131. from .remote_device import _remote_device
  132. from .rendezvous import (
  133. _create_store_from_options,
  134. register_rendezvous_handler,
  135. rendezvous,
  136. )
  137. set_debug_level_from_env()
  138. else:
  139. # This stub is sufficient to get
  140. # python test/test_public_bindings.py -k test_correct_module_names
  141. # working even when USE_DISTRIBUTED=0. Feel free to add more
  142. # stubs as necessary.
  143. # We cannot define stubs directly because they confuse pyre
  144. class _ProcessGroupStub:
  145. pass
  146. sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub # type: ignore[attr-defined]