| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
- import os
- from dataclasses import dataclass
- from typing import Optional, Set
- from horovod.ray.runner import Coordinator
- from horovod.ray.utils import detect_nics, nics_to_env_var
- from horovod.runner.common.util import secret, timeout
- import ray
- from ray.train._internal.utils import update_env_vars
- from ray.train._internal.worker_group import Worker, WorkerGroup
- from ray.train.backend import Backend, BackendConfig
- from ray.util import PublicAPI
@PublicAPI(stability="beta")
@dataclass
class HorovodConfig(BackendConfig):
    """Configurations for Horovod setup.

    See https://github.com/horovod/horovod/blob/master/horovod/runner/common/util/settings.py # noqa: E501

    Args:
        nics (Optional[Set[str]]): Network interfaces that can be used for
            communication.
        verbose: Horovod logging verbosity.
        key (Optional[str]): Secret used for communication between workers.
            When left as ``None`` a key is generated with
            ``secret.make_secret_key()`` during initialization.
        ssh_port (Optional[int]): Port for SSH server running on worker nodes.
        ssh_identity_file (Optional[str]): Path to the identity file to
            ssh into different hosts on the cluster.
        ssh_str (Optional[str]): CAUTION WHEN USING THIS. Private key
            file contents. Writes the private key to ssh_identity_file
            if that file does not exist yet.
        timeout_s: Timeout parameter for Gloo rendezvous.
        placement_group_timeout_s: Timeout parameter for Ray
            Placement Group creation. Currently unused.

    Raises:
        ValueError: If ``ssh_str`` is provided without ``ssh_identity_file``.
    """

    nics: Optional[Set[str]] = None
    verbose: int = 1
    key: Optional[str] = None
    ssh_port: Optional[int] = None
    ssh_identity_file: Optional[str] = None
    ssh_str: Optional[str] = None
    timeout_s: int = 300
    placement_group_timeout_s: int = 100

    @property
    def start_timeout(self):
        """Return a new Horovod ``Timeout`` built from ``timeout_s``."""
        return timeout.Timeout(
            self.timeout_s,
            message="Timed out waiting for {activity}. Please "
            "check connectivity between servers. You "
            "may need to increase the --start-timeout "
            "parameter if you have too many servers.",
        )

    def __post_init__(self):
        """Materialize the SSH key file if requested and default the secret key."""
        if self.ssh_str:
            if self.ssh_identity_file is None:
                # Previously this surfaced as an opaque TypeError raised from
                # os.path.exists(None); fail with an actionable message instead.
                raise ValueError(
                    "`ssh_identity_file` must be set when `ssh_str` is provided."
                )
            if not os.path.exists(self.ssh_identity_file):
                # Create the file with owner-only permissions from the start so
                # the private key is never reachable through a file that was
                # created with the (usually more permissive) default mode.
                fd = os.open(
                    self.ssh_identity_file,
                    os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
                    0o600,
                )
                with os.fdopen(fd, "w") as f:
                    # The creation mode only applies to newly created files and
                    # is filtered by the umask, so enforce 0o600 explicitly.
                    os.chmod(self.ssh_identity_file, 0o600)
                    f.write(self.ssh_str)

        if self.key is None:
            self.key = secret.make_secret_key()

    @property
    def backend_cls(self):
        """Return the backend class that consumes this config."""
        return _HorovodBackend
class _HorovodBackend(Backend):
    """Backend that wires a worker group together for Horovod."""

    share_cuda_visible_devices: bool = True

    def on_start(self, worker_group: WorkerGroup, backend_config: HorovodConfig):
        """Set Horovod environment variables on every worker.

        The steps run strictly in this order: per-worker rank/size variables,
        coordinator registration, per-rank variables returned by the
        coordinator, rendezvous setup, NIC detection, and finally a broadcast
        of the combined coordinator/NIC variables to all workers.

        Args:
            worker_group: Group of workers to configure.
            backend_config: Horovod configuration; also handed to the
                coordinator as its settings object.
        """
        # NOTE: The Horovod backend works with the V1 WorkerGroup directly rather
        # than BaseWorkerGroup, because it needs worker metadata (node_id,
        # hostname) that is specific to the V1 implementation. Horovod does not
        # support the V2 WorkerGroup.

        # TODO(matt): Implement placement group strategies in BackendExecutor.

        # Step 1: seed each worker with its Horovod rank/size/hostname variables.
        world_size = len(worker_group)
        init_refs = [
            worker_group.execute_single_async(
                rank,
                _init_env_vars,
                rank,
                world_size,
                worker_group.workers[rank].metadata.node_id,
            )
            for rank in range(world_size)
        ]
        ray.get(init_refs)

        # Step 2: create the Horovod Ray Coordinator, using backend_config as
        # its settings.
        coordinator = self.coordinator = Coordinator(backend_config)

        # Collect node ids and hostnames for all workers, in worker order.
        workers = worker_group.workers
        node_ids = [worker.metadata.node_id for worker in workers]
        hostnames = [worker.metadata.hostname for worker in workers]

        # Register every worker with the coordinator. This assumes that the
        # hostname ordering is the same as the node id ordering.
        for rank, (hostname, node_id) in enumerate(zip(hostnames, node_ids)):
            coordinator.register(hostname, node_id, rank)
        env_by_rank = coordinator.finalize_registration()

        # Step 3: push the per-rank variables produced by registration.
        update_refs = [
            worker_group.execute_single_async(rank, update_env_vars, rank_env)
            for rank, rank_env in env_by_rank.items()
        ]
        ray.get(update_refs)

        # Step 4: rendezvous, then NIC detection with one worker per host/node
        # (the first worker found for each distinct node id).
        coordinator_envs = coordinator.establish_rendezvous()

        node_workers = [
            _HorovodWorkerWrapper(workers[node_ids.index(node_id)])
            for node_id in set(node_ids)
        ]
        assert len(node_workers) == len(coordinator.hostnames)

        nics = detect_nics(
            backend_config,
            all_host_names=list(coordinator.hostnames),
            node_workers=node_workers,
        )
        coordinator_envs.update(nics_to_env_var(nics))

        # Step 5: broadcast the combined variables to every worker.
        worker_group.execute(update_env_vars, coordinator_envs)
def _init_env_vars(world_rank: int, world_size: int, node_id: str):
    """Export the Horovod hostname, rank and size into ``os.environ``.

    Args:
        world_rank: Rank of this worker; stored as a string in HOROVOD_RANK.
        world_size: Total number of workers; stored as a string in
            HOROVOD_SIZE.
        node_id: Node identifier; stored unchanged in HOROVOD_HOSTNAME.
    """
    os.environ.update(
        {
            "HOROVOD_HOSTNAME": node_id,
            "HOROVOD_RANK": str(world_rank),
            "HOROVOD_SIZE": str(world_size),
        }
    )
- # TODO(tgaddair): temporary workaround for Horovod's worker discovery logic,
- # which requires passing in an extra parameter as part of the RayExecutor
- # API. This will be removed in the future as we migrate more of the
- # RayExecutor utils into Ray Train.
- # See: https://github.com/horovod/horovod/blob/v0.23.0/horovod/ray/driver_service.py#L9 # noqa: E501
@dataclass
class _HorovodWorkerWrapper:
    """Wrap a worker so it offers an ``execute.remote(func, ...)`` call.

    Attributes:
        w: The wrapped worker whose ``actor`` receives the forwarded calls.
    """

    w: Worker

    @property
    def execute(self):
        """Return a handle whose ``remote`` forwards to the worker's actor."""
        wrapped_worker = self.w

        class ExecuteHandle:
            def remote(self, func, *args, **kwargs):
                # None occupies the positional slot between ``func`` and the
                # caller's own arguments; it is a throwaway placeholder.
                return wrapped_worker.actor._RayTrainWorker__execute.remote(
                    func, None, *args, **kwargs
                )

        return ExecuteHandle()
|