| 123456789101112131415161718192021222324252627282930313233343536 |
- # Copyright (c) Meta Platforms, Inc. and affiliates
- """
- Global configuration flags for torch.distributed
- """
- import os
- import sys
- from typing import TYPE_CHECKING
- from torch.utils._config_module import Config, install_config_module
# Public configuration flags exported by a star-import of this module.
__all__ = [
    "compile_on_one_rank",
    "use_torchcomms",
]
# When enabled, coordinates are computed at runtime via a custom op rather
# than being baked in at compile time. This allows compiling on one rank
# and running on multiple ranks.
#
# Read once, at import time, from TORCH_DISTRIBUTED_COMPILE_ON_ONE_RANK.
# The value is parsed as text instead of being passed to bool(): bool() of
# any non-empty string is True, which made "0" or "false" *enable* the flag.
# Unset, empty, and the usual "off" spellings ("0", "false", "no", "off";
# case-insensitive, surrounding whitespace ignored) give False; any other
# value gives True.
# NOTE(review): kept as a single expression so no helper name is left at
# module scope, where install_config_module would presumably register it as
# an extra config entry -- confirm before refactoring into temporaries.
compile_on_one_rank: bool = os.environ.get(
    "TORCH_DISTRIBUTED_COMPILE_ON_ONE_RANK", ""
).strip().lower() not in ("", "0", "false", "no", "off")
- # When enabled, uses TorchComms for communication backend instead of the
- # traditional ProcessGroup backends (NCCL, Gloo, etc.).
- use_torchcomms: bool = Config(
- default=False,
- env_name_default="TORCH_DISTRIBUTED_USE_TORCHCOMMS",
- )
# Static type checkers only: typing.TYPE_CHECKING is False at runtime, so this
# star-import never executes. Presumably it supplies type declarations for the
# helpers install_config_module attaches (patch, save_config, ...) -- confirm
# against torch.utils._config_typing.
if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403
# Turn this module into a managed config module (adds patch, save_config,
# invalid-config checks, etc.). Keep this as the last statement: it receives
# the module object at this point, so flags defined after the call would
# presumably not be registered.
install_config_module(sys.modules[__name__])
|