config.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. import logging
  2. import os
  3. from dataclasses import dataclass
  4. from datetime import timedelta
  5. from typing import Optional
  6. import torch
  7. import torch.distributed as dist
  8. from packaging.version import Version
  9. import ray
  10. from ray._common.network_utils import build_address
  11. from ray._private import ray_constants
  12. from ray.air._internal.device_manager import register_custom_torch_dist_backend
  13. from ray.exceptions import GetTimeoutError
  14. from ray.train._internal.base_worker_group import BaseWorkerGroup
  15. from ray.train._internal.utils import get_address_and_port
  16. from ray.train.backend import Backend, BackendConfig
  17. from ray.train.constants import (
  18. DEFAULT_TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
  19. TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
  20. )
  21. from ray.util import PublicAPI
  22. logger = logging.getLogger(__name__)
  23. class TorchConfigContextManager:
  24. def __enter__(self):
  25. # Set default cuda device
  26. if torch.cuda.is_available():
  27. device = ray.train.torch.get_device()
  28. if device.type == "cuda":
  29. torch.cuda.set_device(device)
  30. def __exit__(self, type, value, traceback):
  31. # Propagate exceptions if any
  32. return False
@PublicAPI(stability="stable")
@dataclass
class TorchConfig(BackendConfig):
    """Configuration for torch process group setup.

    See https://pytorch.org/docs/stable/distributed.html for more info.

    Args:
        backend: The backend to use for training.
            See ``torch.distributed.init_process_group`` for more info and
            valid values.
            If set to None, nccl will be used if GPUs are requested, else gloo
            will be used.
        init_method: The initialization method to use. Either "env"
            for environment variable initialization or "tcp" for TCP
            initialization. Defaults to "env".
        timeout_s: Seconds for process group operations to timeout.
            Defaults to 1800.
    """

    # None defers the nccl/gloo choice until the backend starts.
    backend: Optional[str] = None
    init_method: str = "env"
    timeout_s: int = 1800

    @property
    def backend_cls(self):
        """Return the backend class paired with this config (``_TorchBackend``)."""
        return _TorchBackend

    @property
    def train_func_context(self):
        """Return the ``TorchConfigContextManager`` class (not an instance)."""
        return TorchConfigContextManager
  58. def _is_backend_nccl(backend: str) -> bool:
  59. # Check containment because comma separated lists of backends like cpu:gloo,cuda:nccl are supported.
  60. return backend == "nccl" or any(
  61. item.split(":")[1] == "nccl"
  62. for item in backend.split(",")
  63. if item.startswith("cuda:")
  64. )
  65. def _setup_torch_process_group(
  66. backend: str,
  67. world_rank: int,
  68. world_size: int,
  69. init_method: str,
  70. timeout_s: int = 1800,
  71. ):
  72. """Connects the distributed PyTorch backend.
  73. Args:
  74. backend: The backend (nccl, gloo, etc.) to use for training.
  75. world_rank: Rank of the current worker.
  76. world_size: Number of workers participating in the job.
  77. init_method: URL specifying how to initialize the process group.
  78. timeout_s: Seconds for process group operations to timeout.
  79. """
  80. if world_rank == 0:
  81. logger.info(
  82. f"Setting up process group for: {init_method} [rank={world_rank}, "
  83. f"world_size={world_size}]"
  84. )
  85. else:
  86. logger.debug(
  87. f"Setting up process group for: {init_method} [rank={world_rank}, "
  88. f"world_size={world_size}]"
  89. )
  90. logger.debug(f"using {backend}")
  91. if _is_backend_nccl(backend):
  92. # See https://github.com/pytorch/pytorch/blob/c263bd43e8e8502d4726643bc6fd046f0130ac0e/torch/distributed/distributed_c10d.py#L803-L823 # noqa: E501
  93. # We do not use TORCH_NCCL_BLOCKING_WAIT due to performance overhead.
  94. if Version(torch.__version__) < Version("2.2.0"):
  95. TORCH_NCCL_ASYNC_ERROR_HANDLING_ENV_VAR = "NCCL_ASYNC_ERROR_HANDLING"
  96. TORCH_NCCL_BLOCKING_WAIT_ENV_VAR = "NCCL_BLOCKING_WAIT"
  97. else:
  98. TORCH_NCCL_ASYNC_ERROR_HANDLING_ENV_VAR = "TORCH_NCCL_ASYNC_ERROR_HANDLING"
  99. TORCH_NCCL_BLOCKING_WAIT_ENV_VAR = "TORCH_NCCL_BLOCKING_WAIT"
  100. if (
  101. TORCH_NCCL_ASYNC_ERROR_HANDLING_ENV_VAR not in os.environ
  102. and TORCH_NCCL_BLOCKING_WAIT_ENV_VAR not in os.environ
  103. ):
  104. logger.debug(
  105. f"Setting {TORCH_NCCL_ASYNC_ERROR_HANDLING_ENV_VAR}=1 to fail if NCCL collective communication operations are timing out. " # noqa: E501
  106. f"To override this behavior, you can set {TORCH_NCCL_ASYNC_ERROR_HANDLING_ENV_VAR}=0." # noqa: E501
  107. )
  108. os.environ[TORCH_NCCL_ASYNC_ERROR_HANDLING_ENV_VAR] = "1"
  109. elif backend == "hccl":
  110. register_custom_torch_dist_backend(backend)
  111. dist.init_process_group(
  112. backend=backend,
  113. init_method=init_method,
  114. rank=world_rank,
  115. world_size=world_size,
  116. timeout=timedelta(seconds=timeout_s),
  117. )
  118. def _shutdown_torch(destroy_process_group=False):
  119. from ray.air._internal.torch_utils import get_devices
  120. devices = get_devices()
  121. if destroy_process_group:
  122. dist.destroy_process_group()
  123. if torch.cuda.is_available():
  124. for device in devices:
  125. with torch.cuda.device(device):
  126. torch.cuda.empty_cache()
  127. def _set_torch_distributed_env_vars():
  128. # Same env vars as in
  129. # https://pytorch.org/docs/stable/elastic/run.html#environment-variables
  130. from ray.train.torch import get_device
  131. context = ray.train.get_context()
  132. os.environ["LOCAL_RANK"] = str(context.get_local_rank())
  133. os.environ["RANK"] = str(context.get_world_rank())
  134. os.environ["LOCAL_WORLD_SIZE"] = str(context.get_local_world_size())
  135. os.environ["WORLD_SIZE"] = str(context.get_world_size())
  136. os.environ["NODE_RANK"] = str(context.get_node_rank())
  137. # Makes sure Hugging Face Accelerate uses the correct device
  138. device = get_device()
  139. os.environ["ACCELERATE_TORCH_DEVICE"] = str(device)
  140. class _TorchBackend(Backend):
  141. share_cuda_visible_devices: bool = True
  142. def on_start(self, worker_group: BaseWorkerGroup, backend_config: TorchConfig):
  143. if dist.is_available():
  144. # Set the appropriate training backend.
  145. if backend_config.backend is None:
  146. resources = worker_group.get_resources_per_worker()
  147. num_gpus_per_worker = resources.get("GPU", 0)
  148. if num_gpus_per_worker > 0:
  149. backend = "nccl"
  150. else:
  151. backend = "gloo"
  152. else:
  153. backend = backend_config.backend
  154. master_addr, master_port = worker_group.execute_single(
  155. 0, get_address_and_port
  156. )
  157. if backend_config.init_method == "env":
  158. def set_env_vars(addr, port):
  159. os.environ["MASTER_ADDR"] = addr
  160. os.environ["MASTER_PORT"] = str(port)
  161. worker_group.execute(set_env_vars, addr=master_addr, port=master_port)
  162. url = "env://"
  163. elif backend_config.init_method == "tcp":
  164. url = f"tcp://{build_address(master_addr, master_port)}"
  165. else:
  166. raise ValueError(
  167. f"The provided init_method ("
  168. f"{backend_config.init_method}) is not supported. Must "
  169. f"be either 'env' or 'tcp'."
  170. )
  171. setup_futures = []
  172. for i in range(len(worker_group)):
  173. setup_futures.append(
  174. worker_group.execute_single_async(
  175. i,
  176. _setup_torch_process_group,
  177. backend=backend,
  178. world_rank=i,
  179. world_size=len(worker_group),
  180. init_method=url,
  181. timeout_s=backend_config.timeout_s,
  182. )
  183. )
  184. ray.get(setup_futures)
  185. else:
  186. raise RuntimeError("Distributed torch is not available.")
  187. def on_shutdown(self, worker_group: BaseWorkerGroup, backend_config):
  188. futures = worker_group.execute_async(
  189. _shutdown_torch,
  190. destroy_process_group=len(worker_group) > 1,
  191. )
  192. timeout_s = ray_constants.env_integer(
  193. TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
  194. DEFAULT_TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
  195. )
  196. try:
  197. ray.get(futures, timeout=timeout_s)
  198. except GetTimeoutError:
  199. logger.warning(
  200. f"Torch process group shutdown timed out after {timeout_s} seconds"
  201. )
  202. def on_training_start(
  203. self, worker_group: BaseWorkerGroup, backend_config: BackendConfig
  204. ):
  205. worker_group.execute(_set_torch_distributed_env_vars)