symmetric_run.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
"""Symmetric Run for Ray.

Defines the ``symmetric_run`` click command. It is meant to be launched with
the same arguments on every node. Each node compares the resolved ``--address``
host with its own interface addresses:

* The matching node starts the Ray head, runs the entrypoint command, and stops Ray.
* Every other node starts a blocking Ray worker that joins the head.
"""
  2. import socket
  3. import subprocess
  4. import sys
  5. import time
  6. from typing import List
  7. import click
  8. import ray
  9. from ray._private.ray_constants import env_integer
  10. from ray._raylet import GcsClient
  11. from ray.exceptions import RpcError
  12. import psutil
# Default timeout, in seconds, for check_cluster_ready (head waiting for
# --min-nodes nodes) and check_head_node_ready (worker waiting for the head's GCS).
# The value comes from env_integer for RAY_SYMMETRIC_RUN_CLUSTER_WAIT_TIMEOUT,
# with 30 as the fallback.
CLUSTER_WAIT_TIMEOUT = env_integer("RAY_SYMMETRIC_RUN_CLUSTER_WAIT_TIMEOUT", 30)
  14. def check_ray_already_started() -> bool:
  15. import ray._private.services as services
  16. # Try auto-detecting the Ray instance.
  17. running_gcs_addresses = services.find_gcs_addresses()
  18. return len(running_gcs_addresses) > 0
  19. def check_cluster_ready(nnodes, timeout=CLUSTER_WAIT_TIMEOUT):
  20. """Wait for all nodes to start.
  21. Raises an exception if the nodes don't start in time.
  22. """
  23. start_time = time.time()
  24. current_nodes = 1
  25. ray.init(ignore_reinit_error=True)
  26. while time.time() - start_time < timeout:
  27. time.sleep(5)
  28. current_nodes = len(ray.nodes())
  29. if current_nodes == nnodes:
  30. return True
  31. else:
  32. click.echo(
  33. f"Waiting for nodes to start... {current_nodes}/{nnodes} nodes started"
  34. )
  35. return False
  36. def check_head_node_ready(address: str, timeout=CLUSTER_WAIT_TIMEOUT):
  37. start_time = time.time()
  38. gcs_client = GcsClient(address=address)
  39. while time.time() - start_time < timeout:
  40. try:
  41. gcs_client.check_alive([], timeout=1)
  42. click.echo("Ray cluster is ready!")
  43. return True
  44. except RpcError:
  45. pass
  46. time.sleep(5)
  47. return False
  48. def curate_and_validate_ray_start_args(run_and_start_args: List[str]) -> List[str]:
  49. # Reparse the arguments to remove symmetric_run arguments.
  50. ctx = symmetric_run.make_context("_", run_and_start_args, resilient_parsing=True)
  51. cleaned_args = list(ctx.params["ray_args_and_entrypoint"])
  52. for arg in cleaned_args:
  53. if arg == "--head":
  54. raise click.ClickException("Cannot use --head option in symmetric_run.")
  55. if arg == "--node-ip-address":
  56. raise click.ClickException(
  57. "Cannot use --node-ip-address option in symmetric_run."
  58. )
  59. if arg == "--port":
  60. raise click.ClickException("Cannot use --port option in symmetric_run.")
  61. if arg == "--block":
  62. raise click.ClickException("Cannot use --block option in symmetric_run.")
  63. return cleaned_args
  64. @click.command(
  65. name="symmetric_run",
  66. context_settings={"ignore_unknown_options": True, "allow_extra_args": True},
  67. help="""Command to start Ray across all nodes and execute an entrypoint command.
  68. USAGE:
  69. ray symmetric-run --address ADDRESS
  70. [--min-nodes NUM_NODES] [RAY_START_OPTIONS] -- [ENTRYPOINT_COMMAND]
  71. DESCRIPTION:
  72. This command (1) starts a Ray cluster across all nodes,
  73. (2) runs a command on the head node, and (3) stops the Ray cluster.
  74. The '--' separator is required to distinguish between Ray start arguments
  75. and the entrypoint command. The --min-nodes option is optional and
  76. can be used to wait for a specific number of nodes to start.
  77. EXAMPLES:
  78. # Start Ray with default settings and run a Python script
  79. ray symmetric-run --address 127.0.0.1:6379 -- python my_script.py
  80. # Start Ray with specific head node and run a command
  81. ray symmetric-run --address 127.0.0.1:6379 --min-nodes 4 -- python train_model.py --epochs=100
  82. # Start Ray and run a multi-word command
  83. ray symmetric-run --address 127.0.0.1:6379 --min-nodes 4 --num-cpus=4 -- python -m my_module --config=prod
  84. RAY START OPTIONS:
  85. Most ray start command options are supported. Arguments that are not
  86. supported are: --head, --node-ip-address, --port, --block.
  87. SEPARATOR REQUIREMENT:
  88. The '--' separator is mandatory and must appear between Ray start
  89. arguments and the entrypoint command. This ensures clear separation
  90. between the two sets of arguments.
  91. """,
  92. )
  93. @click.option(
  94. "--address", required=True, type=str, help="The address of the Ray cluster."
  95. )
  96. @click.option(
  97. "--min-nodes",
  98. type=int,
  99. help="If provided, wait for this number of nodes to start.",
  100. )
  101. @click.argument("ray_args_and_entrypoint", nargs=-1, type=click.UNPROCESSED)
  102. def symmetric_run(address, min_nodes, ray_args_and_entrypoint):
  103. all_args = sys.argv[1:]
  104. if all_args and all_args[0] == "symmetric-run":
  105. all_args = all_args[1:]
  106. try:
  107. separator = all_args.index("--")
  108. except ValueError:
  109. raise click.ClickException(
  110. "No separator '--' found in arguments. Please use '--' to "
  111. "separate Ray start arguments and the entrypoint command."
  112. )
  113. run_and_start_args, entrypoint_on_head = (
  114. all_args[:separator],
  115. all_args[separator + 1 :],
  116. )
  117. ray_start_args = curate_and_validate_ray_start_args(run_and_start_args)
  118. min_nodes = 1 if min_nodes is None else min_nodes
  119. if not entrypoint_on_head:
  120. raise click.ClickException("No entrypoint command provided.")
  121. if check_ray_already_started():
  122. raise click.ClickException("Ray is already started on this node.")
  123. # 1. Parse address and check if we are on the head node.
  124. gcs_host_port = ray._common.network_utils.parse_address(address)
  125. if gcs_host_port is None:
  126. raise click.ClickException(
  127. f"Invalid address format: {address}, should be `host:port`"
  128. )
  129. gcs_host, gcs_port = gcs_host_port
  130. try:
  131. # AF_UNSPEC allows resolving both IPv4 and IPv6
  132. addrinfo = socket.getaddrinfo(
  133. gcs_host, gcs_port, socket.AF_UNSPEC, socket.SOCK_STREAM
  134. )
  135. resolved_gcs_host = addrinfo[0][4][0]
  136. except socket.gaierror:
  137. raise click.ClickException(f"Could not resolve hostname: {gcs_host}")
  138. my_ips = []
  139. for iface, addrs in psutil.net_if_addrs().items():
  140. for addr in addrs:
  141. # Look for AF_INET (IPv4) or AF_INET6 (IPv6)
  142. if addr.family in [
  143. socket.AddressFamily.AF_INET,
  144. socket.AddressFamily.AF_INET6,
  145. ]:
  146. my_ips.append(addr.address)
  147. if min_nodes > 1:
  148. # Ban localhost ips if we are not running on a single node
  149. # to avoid starting N head nodes
  150. my_ips = [ip for ip in my_ips if ip != "127.0.0.1" and ip != "::1"]
  151. is_head = resolved_gcs_host in my_ips
  152. result = None
  153. # 2. Start Ray and run commands.
  154. try:
  155. if is_head:
  156. # On the head node, start Ray, run the command, then stop Ray.
  157. click.echo("On head node. Starting Ray cluster head...")
  158. # Build the ray start command with all parameters
  159. ray_start_cmd = [
  160. "ray",
  161. "start",
  162. "--head",
  163. f"--node-ip-address={resolved_gcs_host}",
  164. f"--port={gcs_port}",
  165. *ray_start_args,
  166. ]
  167. # Start Ray head. This runs in the background and hides output.
  168. subprocess.run(ray_start_cmd, check=True, capture_output=True)
  169. click.echo("Head node started.")
  170. click.echo("=======================")
  171. if min_nodes > 1 and not check_cluster_ready(min_nodes):
  172. raise click.ClickException(
  173. "Timed out waiting for other nodes to start."
  174. )
  175. click.echo(
  176. f"Running command on head node: {entrypoint_on_head}",
  177. )
  178. click.echo("=======================")
  179. result = subprocess.run(entrypoint_on_head)
  180. click.echo("=======================")
  181. else:
  182. # On a worker node, start Ray and connect to the head.
  183. click.echo(f"On worker node. Connecting to Ray cluster at {address}...")
  184. if not check_head_node_ready(address):
  185. raise click.ClickException("Timed out waiting for head node to start.")
  186. # Build the ray start command for worker nodes with all parameters
  187. ray_start_cmd = [
  188. "ray",
  189. "start",
  190. "--address",
  191. address,
  192. "--block",
  193. *ray_start_args,
  194. ]
  195. # This command will block until the Ray cluster is stopped.
  196. subprocess.run(ray_start_cmd, check=True)
  197. except subprocess.CalledProcessError as e:
  198. click.echo(f"Failed to start Ray: {e}", err=True)
  199. if e.stdout:
  200. click.echo(f"stdout:\n{e.stdout.decode()}", err=True)
  201. if e.stderr:
  202. click.echo(f"stderr:\n{e.stderr.decode()}", err=True)
  203. except KeyboardInterrupt:
  204. # This can be triggered by ctrl-c on the user's side.
  205. click.echo("Interrupted by user.", err=True)
  206. finally:
  207. # Stop Ray cluster.
  208. subprocess.run(["ray", "stop"])
  209. # Propagate the exit code of the user script.
  210. if result is not None and result.returncode != 0:
  211. click.echo(f"Command failed with return code {result.returncode}", err=True)
  212. sys.exit(result.returncode)
  213. if __name__ == "__main__":
  214. symmetric_run()