config.py 999 B

123456789101112131415161718192021222324252627282930313233343536
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. """
  3. Global configuration flags for torch.distributed
  4. """
  5. import os
  6. import sys
  7. from typing import TYPE_CHECKING
  8. from torch.utils._config_module import Config, install_config_module
  9. __all__ = ["compile_on_one_rank", "use_torchcomms"]
  10. # When enabled, coordinates are computed at runtime via a custom op rather
  11. # than being baked in at compile time. This allows compiling on one rank
  12. # and running on multiple ranks.
  13. compile_on_one_rank: bool = bool(
  14. os.environ.get("TORCH_DISTRIBUTED_COMPILE_ON_ONE_RANK", False)
  15. )
  16. # When enabled, uses TorchComms for communication backend instead of the
  17. # traditional ProcessGroup backends (NCCL, Gloo, etc.).
  18. use_torchcomms: bool = Config(
  19. default=False,
  20. env_name_default="TORCH_DISTRIBUTED_USE_TORCHCOMMS",
  21. )
# Only static type checkers see this import (typing.TYPE_CHECKING is False at
# runtime). Presumably _config_typing provides type information for the
# helpers that install_config_module attaches below — TODO confirm.
if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403
# adds patch, save_config, invalid config checks, etc
# NOTE(review): keep this as the module's last statement — it is handed this
# very module object and presumably registers the settings defined above, so
# anything defined after this call would likely not be treated as config.
install_config_module(sys.modules[__name__])