distributed.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. """
  2. Manages process groups for distributed compilation in TorchDynamo.
  3. This module handles the initialization and management of process groups used for
  4. distributed compilation. Key features:
  5. - Lazy initialization of compilation process groups
  6. - Only creates groups when distributed mode is enabled and available
  7. - Integrates with compiler_collectives configuration setting
  8. - Provides one global, cached process group each for compilation and guard coordination
  9. The process group is created only when needed and if the distributed environment
  10. is properly initialized, making it safe to import and use this module even in
  11. non-distributed scenarios.
  12. """
  13. from typing import Optional
  14. import torch.distributed as dist
  15. from . import config
# Module-level cache for the process group handed out by get_compile_pg().
# Starts as None and is assigned the first time that function creates the group.
_COMPILE_PG: Optional[dist.ProcessGroup] = None
# Module-level cache for the process group handed out by get_guard_pg().
# Starts as None and is assigned the first time that function creates the group.
_GUARD_PG: Optional[dist.ProcessGroup] = None
  18. def get_compile_pg() -> Optional[dist.ProcessGroup]:
  19. if (
  20. config.enable_compiler_collectives
  21. and dist.is_available()
  22. and dist.is_initialized()
  23. ):
  24. global _COMPILE_PG
  25. if _COMPILE_PG is None:
  26. # , timeout=datetime.timedelta(seconds=2)
  27. _COMPILE_PG = dist.distributed_c10d._new_group_with_tag(
  28. pg_tag="pt2_compile_pg"
  29. )
  30. return _COMPILE_PG
  31. return None
  32. # NB: Unlike get_compile_pg, this is only called when guard collectives were
  33. # explicitly requested
  34. def get_guard_pg() -> Optional[dist.ProcessGroup]:
  35. if dist.is_available() and dist.is_initialized():
  36. global _GUARD_PG
  37. if _GUARD_PG is None:
  38. _GUARD_PG = dist.distributed_c10d._new_group_with_tag(pg_tag="pt2_guard_pg")
  39. return _GUARD_PG
  40. return None