config.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. import os
  2. from dataclasses import dataclass
  3. from typing import Optional, Set
  4. from horovod.ray.runner import Coordinator
  5. from horovod.ray.utils import detect_nics, nics_to_env_var
  6. from horovod.runner.common.util import secret, timeout
  7. import ray
  8. from ray.train._internal.utils import update_env_vars
  9. from ray.train._internal.worker_group import Worker, WorkerGroup
  10. from ray.train.backend import Backend, BackendConfig
  11. from ray.util import PublicAPI
  @PublicAPI(stability="beta")
  @dataclass
  class HorovodConfig(BackendConfig):
      """Configurations for Horovod setup.

      See https://github.com/horovod/horovod/blob/master/horovod/runner/common/util/settings.py # noqa: E501

      Args:
          nics (Optional[Set[str]]): Network interfaces that can be used for
              communication.
          verbose: Horovod logging verbosity.
          key (Optional[str]): Secret used for communication between workers.
              When left as ``None`` a key is generated with
              ``secret.make_secret_key()``.
          ssh_port (Optional[int]): Port for SSH server running on worker nodes.
          ssh_identity_file (Optional[str]): Path to the identity file to
              ssh into different hosts on the cluster.
          ssh_str (Optional[str]): CAUTION WHEN USING THIS. Private key
              file contents. Writes the private key to ssh_identity_file
              if that file does not exist yet.
          timeout_s: Timeout parameter for Gloo rendezvous.
          placement_group_timeout_s: Timeout parameter for Ray
              Placement Group creation. Currently unused.

      Raises:
          ValueError: If ``ssh_str`` is given without ``ssh_identity_file``.
      """

      nics: Optional[Set[str]] = None
      verbose: int = 1
      key: Optional[str] = None
      ssh_port: Optional[int] = None
      ssh_identity_file: Optional[str] = None
      ssh_str: Optional[str] = None
      timeout_s: int = 300
      placement_group_timeout_s: int = 100

      @property
      def start_timeout(self):
          """Return a Horovod ``Timeout`` built from ``timeout_s``."""
          return timeout.Timeout(
              self.timeout_s,
              message="Timed out waiting for {activity}. Please "
              "check connectivity between servers. You "
              "may need to increase the --start-timeout "
              "parameter if you have too many servers.",
          )

      def __post_init__(self):
          """Materialize the SSH key file (if requested) and default the secret."""
          if self.ssh_str:
              if self.ssh_identity_file is None:
                  # Without this guard os.path.exists(None) raises an opaque
                  # TypeError that does not point at the misconfiguration.
                  raise ValueError(
                      "`ssh_identity_file` must be set when `ssh_str` is provided."
                  )
              if not os.path.exists(self.ssh_identity_file):
                  # Create the file with owner-only permissions from the start
                  # so the private key is never reachable through a file that
                  # was (even briefly) created with default umask permissions.
                  # O_BINARY (Windows only) avoids double newline translation
                  # once the descriptor is wrapped in a text-mode file object.
                  flags = (
                      os.O_WRONLY
                      | os.O_CREAT
                      | os.O_TRUNC
                      | getattr(os, "O_BINARY", 0)
                  )
                  fd = os.open(self.ssh_identity_file, flags, 0o600)
                  with os.fdopen(fd, "w") as f:
                      # The creation mode is filtered by the umask; chmod pins
                      # the final mode to exactly 0o600 as before.
                      os.chmod(self.ssh_identity_file, 0o600)
                      f.write(self.ssh_str)

          if self.key is None:
              self.key = secret.make_secret_key()

      @property
      def backend_cls(self):
          """Return the backend class that consumes this config."""
          return _HorovodBackend
  class _HorovodBackend(Backend):
      """Backend that prepares a Ray Train worker group for Horovod."""

      share_cuda_visible_devices: bool = True

      def on_start(self, worker_group: WorkerGroup, backend_config: HorovodConfig):
          """Configure every worker's environment for a Horovod job.

          Steps, in order: seed the ``HOROVOD_*`` variables on each worker,
          register every worker with a Horovod ``Coordinator``, push the
          per-rank variables returned by ``finalize_registration``, establish
          the rendezvous, detect the NICs, and finally send the resulting
          environment variables to all workers.

          Args:
              worker_group: Workers to configure.
              backend_config: Horovod configuration; it is also handed to the
                  ``Coordinator`` and to ``detect_nics`` as the settings object.
          """
          # NOTE: Horovod backend uses V1 WorkerGroup directly instead of
          # BaseWorkerGroup because it requires direct access to worker metadata
          # (node_id, hostname) that is specific to the V1 implementation.
          # Horovod does not support V2 WorkerGroup.
          # TODO(matt): Implement placement group strategies in BackendExecutor.

          # Seed each worker with its Horovod rank / size / host variables.
          world_size = len(worker_group)
          ray.get(
              [
                  worker_group.execute_single_async(
                      rank,
                      _init_env_vars,
                      rank,
                      world_size,
                      worker_group.workers[rank].metadata.node_id,
                  )
                  for rank in range(world_size)
              ]
          )

          # The Horovod Ray Coordinator takes backend_config as its settings.
          self.coordinator = Coordinator(backend_config)

          # Collect node ids and hostnames in worker (rank) order.
          node_ids = [worker.metadata.node_id for worker in worker_group.workers]
          hostnames = [worker.metadata.hostname for worker in worker_group.workers]

          # Register every worker; this assumes both lists share one ordering.
          for rank, (hostname, node_id) in enumerate(zip(hostnames, node_ids)):
              self.coordinator.register(hostname, node_id, rank)
          env_vars_by_rank = self.coordinator.finalize_registration()

          # Hand each rank the environment variables computed for it.
          ray.get(
              [
                  worker_group.execute_single_async(rank, update_env_vars, env_vars)
                  for rank, env_vars in env_vars_by_rank.items()
              ]
          )

          rendezvous_envs = self.coordinator.establish_rendezvous()

          # Pick one worker per node: the first worker found for each node id.
          node_workers = [
              _HorovodWorkerWrapper(worker_group.workers[node_ids.index(node_id)])
              for node_id in set(node_ids)
          ]
          assert len(node_workers) == len(self.coordinator.hostnames)

          nics = detect_nics(
              backend_config,
              all_host_names=list(self.coordinator.hostnames),
              node_workers=node_workers,
          )
          rendezvous_envs.update(nics_to_env_var(nics))

          worker_group.execute(update_env_vars, rendezvous_envs)
  def _init_env_vars(world_rank: int, world_size: int, node_id: str):
      """Initialize Horovod environment variables.

      Sets ``HOROVOD_HOSTNAME`` to ``node_id`` and ``HOROVOD_RANK`` /
      ``HOROVOD_SIZE`` to the string forms of ``world_rank`` / ``world_size``
      in this process's ``os.environ``.

      Args:
          world_rank: Rank of this worker.
          world_size: Total number of workers.
          node_id: Node identifier, exported as the Horovod hostname.
      """
      os.environ.update(
          {
              "HOROVOD_HOSTNAME": node_id,
              "HOROVOD_RANK": str(world_rank),
              "HOROVOD_SIZE": str(world_size),
          }
      )
  # TODO(tgaddair): temporary workaround for Horovod's worker discovery logic,
  # which requires passing in an extra parameter as part of the RayExecutor
  # API. This will be removed in the future as we migrate more of the
  # RayExecutor utils into Ray Train.
  # See: https://github.com/horovod/horovod/blob/v0.23.0/horovod/ray/driver_service.py#L9 # noqa: E501
  @dataclass
  class _HorovodWorkerWrapper:
      """Adapter exposing a ``Worker`` through an ``execute.remote(...)`` call."""

      # The wrapped Ray Train worker whose actor runs the submitted functions.
      w: Worker

      @property
      def execute(self):
          """Return a handle whose ``remote`` forwards to the worker's actor.

          ``handle.remote(func, *args, **kwargs)`` invokes the actor's
          ``_RayTrainWorker__execute.remote`` with ``func``, a ``None``
          placeholder, and then the given arguments, returning its result.
          """
          worker = self.w

          class ExecuteHandle:
              def remote(self, func, *args, **kwargs):
                  # ``None`` fills the extra leading parameter described in
                  # the TODO above this class.
                  return worker.actor._RayTrainWorker__execute.remote(
                      func, None, *args, **kwargs
                  )

          return ExecuteHandle()