| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242 |
- import logging
- import os
- from dataclasses import dataclass
- from datetime import timedelta
- from typing import Optional
- import torch
- import torch.distributed as dist
- from packaging.version import Version
- import ray
- from ray._common.network_utils import build_address
- from ray._private import ray_constants
- from ray.air._internal.device_manager import register_custom_torch_dist_backend
- from ray.exceptions import GetTimeoutError
- from ray.train._internal.base_worker_group import BaseWorkerGroup
- from ray.train._internal.utils import get_address_and_port
- from ray.train.backend import Backend, BackendConfig
- from ray.train.constants import (
- DEFAULT_TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
- TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
- )
- from ray.util import PublicAPI
- logger = logging.getLogger(__name__)
class TorchConfigContextManager:
    """Context manager exposed through ``TorchConfig.train_func_context``.

    On entry it makes the device reported by ``ray.train.torch.get_device()``
    the default CUDA device (when CUDA is available and that device is a CUDA
    device). On exit it does nothing and never suppresses exceptions.
    """

    def __enter__(self):
        # Nothing to configure on hosts without CUDA.
        if not torch.cuda.is_available():
            return

        # Set default cuda device
        assigned = ray.train.torch.get_device()
        if assigned.type == "cuda":
            torch.cuda.set_device(assigned)

    def __exit__(self, type, value, traceback):
        # Returning False propagates any in-flight exception to the caller.
        return False
@PublicAPI(stability="stable")
@dataclass
class TorchConfig(BackendConfig):
    """Settings used to create the torch distributed process group.

    Background on the underlying options is available at
    https://pytorch.org/docs/stable/distributed.html.

    Args:
        backend: Name of the communication backend. Valid values are the
            ones accepted by ``torch.distributed.init_process_group``.
            When left as None, nccl is used if GPUs are requested and gloo
            is used otherwise.
        init_method: How the process group is initialized: "env" for
            environment variable initialization or "tcp" for TCP
            initialization. Defaults to "env".
        timeout_s: Timeout, in seconds, for process group operations.
    """

    backend: Optional[str] = None
    init_method: str = "env"
    timeout_s: int = 1800

    @property
    def train_func_context(self):
        # Context manager class associated with this config.
        return TorchConfigContextManager

    @property
    def backend_cls(self):
        # Backend implementation class associated with this config.
        return _TorchBackend
- def _is_backend_nccl(backend: str) -> bool:
- # Check containment because comma separated lists of backends like cpu:gloo,cuda:nccl are supported.
- return backend == "nccl" or any(
- item.split(":")[1] == "nccl"
- for item in backend.split(",")
- if item.startswith("cuda:")
- )
def _setup_torch_process_group(
    backend: str,
    world_rank: int,
    world_size: int,
    init_method: str,
    timeout_s: int = 1800,
):
    """Initialize the torch distributed process group on this worker.

    Args:
        backend: The backend (nccl, gloo, etc.) to use for training.
        world_rank: Rank of the current worker.
        world_size: Number of workers participating in the job.
        init_method: URL specifying how to initialize the process group.
        timeout_s: Seconds for process group operations to timeout.
    """
    setup_message = (
        f"Setting up process group for: {init_method} [rank={world_rank}, "
        f"world_size={world_size}]"
    )
    # Rank 0 reports the setup at INFO level; every other rank logs the same
    # message at DEBUG level.
    emit = logger.info if world_rank == 0 else logger.debug
    emit(setup_message)
    logger.debug(f"using {backend}")

    if _is_backend_nccl(backend):
        # See https://github.com/pytorch/pytorch/blob/c263bd43e8e8502d4726643bc6fd046f0130ac0e/torch/distributed/distributed_c10d.py#L803-L823  # noqa: E501
        # We do not use TORCH_NCCL_BLOCKING_WAIT due to performance overhead.
        # The env var names carry a TORCH_ prefix from torch 2.2.0 onwards.
        uses_legacy_names = Version(torch.__version__) < Version("2.2.0")
        async_error_handling_var = (
            "NCCL_ASYNC_ERROR_HANDLING"
            if uses_legacy_names
            else "TORCH_NCCL_ASYNC_ERROR_HANDLING"
        )
        blocking_wait_var = (
            "NCCL_BLOCKING_WAIT" if uses_legacy_names else "TORCH_NCCL_BLOCKING_WAIT"
        )

        # Leave the environment alone if either variable was already set.
        already_configured = (
            async_error_handling_var in os.environ
            or blocking_wait_var in os.environ
        )
        if not already_configured:
            logger.debug(
                f"Setting {async_error_handling_var}=1 to fail if NCCL collective communication operations are timing out. "  # noqa: E501
                f"To override this behavior, you can set {async_error_handling_var}=0."  # noqa: E501
            )
            os.environ[async_error_handling_var] = "1"
    elif backend == "hccl":
        register_custom_torch_dist_backend(backend)

    dist.init_process_group(
        backend=backend,
        init_method=init_method,
        rank=world_rank,
        world_size=world_size,
        timeout=timedelta(seconds=timeout_s),
    )
def _shutdown_torch(destroy_process_group=False):
    """Tear down torch state on this worker.

    Args:
        destroy_process_group: If True, destroy the torch distributed
            process group before the CUDA caches are cleared.
    """
    from ray.air._internal.torch_utils import get_devices

    # Look up this worker's devices up front, before anything is torn down.
    worker_devices = get_devices()

    if destroy_process_group:
        dist.destroy_process_group()

    if not torch.cuda.is_available():
        return

    # Empty the CUDA cache with each of the worker's devices selected in turn.
    for dev in worker_devices:
        with torch.cuda.device(dev):
            torch.cuda.empty_cache()
def _set_torch_distributed_env_vars():
    """Export rank and device information of this worker as env vars.

    Sets LOCAL_RANK, RANK, LOCAL_WORLD_SIZE, WORLD_SIZE and NODE_RANK from the
    Ray Train context, plus ACCELERATE_TORCH_DEVICE from the worker's device.
    """
    # Same env vars as in
    # https://pytorch.org/docs/stable/elastic/run.html#environment-variables
    from ray.train.torch import get_device

    train_context = ray.train.get_context()
    env_sources = (
        ("LOCAL_RANK", train_context.get_local_rank),
        ("RANK", train_context.get_world_rank),
        ("LOCAL_WORLD_SIZE", train_context.get_local_world_size),
        ("WORLD_SIZE", train_context.get_world_size),
        ("NODE_RANK", train_context.get_node_rank),
    )
    # Query and export one value at a time, in the listed order.
    for env_name, read_value in env_sources:
        os.environ[env_name] = str(read_value())

    # Makes sure Hugging Face Accelerate uses the correct device
    os.environ["ACCELERATE_TORCH_DEVICE"] = str(get_device())
class _TorchBackend(Backend):
    """Backend that sets up and shuts down torch.distributed on the workers."""

    share_cuda_visible_devices: bool = True

    def on_start(self, worker_group: BaseWorkerGroup, backend_config: TorchConfig):
        """Create the torch process group across the worker group.

        Args:
            worker_group: The workers to initialize.
            backend_config: Supplies the backend, init method and timeout.

        Raises:
            RuntimeError: If ``torch.distributed`` is not available.
            ValueError: If ``backend_config.init_method`` is neither "env"
                nor "tcp".
        """
        if not dist.is_available():
            raise RuntimeError("Distributed torch is not available.")

        # Set the appropriate training backend.
        backend = backend_config.backend
        if backend is None:
            gpus_per_worker = worker_group.get_resources_per_worker().get("GPU", 0)
            backend = "nccl" if gpus_per_worker > 0 else "gloo"

        # The address and port are obtained from the worker at index 0.
        master_addr, master_port = worker_group.execute_single(
            0, get_address_and_port
        )

        if backend_config.init_method == "env":

            def set_env_vars(addr, port):
                os.environ["MASTER_ADDR"] = addr
                os.environ["MASTER_PORT"] = str(port)

            worker_group.execute(set_env_vars, addr=master_addr, port=master_port)
            url = "env://"
        elif backend_config.init_method == "tcp":
            url = f"tcp://{build_address(master_addr, master_port)}"
        else:
            raise ValueError(
                f"The provided init_method ("
                f"{backend_config.init_method}) is not supported. Must "
                f"be either 'env' or 'tcp'."
            )

        # Launch the setup on every worker, using its index as the world
        # rank, then wait for all of them to finish.
        num_workers = len(worker_group)
        setup_futures = [
            worker_group.execute_single_async(
                rank,
                _setup_torch_process_group,
                backend=backend,
                world_rank=rank,
                world_size=num_workers,
                init_method=url,
                timeout_s=backend_config.timeout_s,
            )
            for rank in range(num_workers)
        ]
        ray.get(setup_futures)

    def on_shutdown(self, worker_group: BaseWorkerGroup, backend_config):
        """Run ``_shutdown_torch`` via the worker group, with a bounded wait.

        The process group is only destroyed when the group has more than one
        worker. A timeout is logged as a warning rather than raised.
        """
        should_destroy = len(worker_group) > 1
        shutdown_futures = worker_group.execute_async(
            _shutdown_torch,
            destroy_process_group=should_destroy,
        )

        # The wait limit comes from an env var, with a default fallback.
        timeout_s = ray_constants.env_integer(
            TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
            DEFAULT_TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
        )
        try:
            ray.get(shutdown_futures, timeout=timeout_s)
        except GetTimeoutError:
            logger.warning(
                f"Torch process group shutdown timed out after {timeout_s} seconds"
            )

    def on_training_start(
        self, worker_group: BaseWorkerGroup, backend_config: BackendConfig
    ):
        """Run ``_set_torch_distributed_env_vars`` via the worker group."""
        worker_group.execute(_set_torch_distributed_env_vars)
|