launch_and_verify_cluster.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
"""
This script automates the process of launching and verifying a Ray cluster using a given
cluster configuration file. It also cleans up the cluster after the verification step,
whether verification succeeds or fails. The script requires one positional command-line
argument: the path to the cluster configuration file.

Usage:
    python launch_and_verify_cluster.py [--no-config-cache] [--retries NUM_RETRIES]
        [--num-expected-nodes NUM_NODES] [--docker-override DOCKER_OVERRIDE]
        [--wheel-override WHEEL_OVERRIDE]
        <cluster_configuration_file_path>
"""
  12. import argparse
  13. import json
  14. import os
  15. import re
  16. import subprocess
  17. import sys
  18. import tempfile
  19. import time
  20. import traceback
  21. from pathlib import Path
  22. import boto3
  23. import botocore
  24. import yaml
  25. from google.cloud import storage
  26. import ray
  27. from ray.autoscaler._private.aws.config import RAY
  28. def check_arguments():
  29. """
  30. Check command line arguments and return the cluster configuration file path, the
  31. number of retries, the number of expected nodes, and the value of the
  32. --no-config-cache flag.
  33. """
  34. parser = argparse.ArgumentParser(description="Launch and verify a Ray cluster")
  35. parser.add_argument(
  36. "--no-config-cache",
  37. action="store_true",
  38. help="Pass the --no-config-cache flag to Ray CLI commands",
  39. )
  40. parser.add_argument(
  41. "--retries",
  42. type=int,
  43. default=3,
  44. help="Number of retries for verifying Ray is running (default: 3)",
  45. )
  46. parser.add_argument(
  47. "--num-expected-nodes",
  48. type=int,
  49. default=1,
  50. help="Number of nodes for verifying Ray is running (default: 1)",
  51. )
  52. parser.add_argument(
  53. "--docker-override",
  54. choices=["disable", "latest", "nightly", "commit"],
  55. default="disable",
  56. help="Override the docker image used for the head node and worker nodes",
  57. )
  58. parser.add_argument(
  59. "--wheel-override",
  60. type=str,
  61. default="",
  62. help="Override the wheel used for the head node and worker nodes",
  63. )
  64. parser.add_argument(
  65. "cluster_config", type=str, help="Path to the cluster configuration file"
  66. )
  67. args = parser.parse_args()
  68. assert not (
  69. args.docker_override != "disable" and args.wheel_override != ""
  70. ), "Cannot override both docker and wheel"
  71. return (
  72. args.cluster_config,
  73. args.retries,
  74. args.no_config_cache,
  75. args.num_expected_nodes,
  76. args.docker_override,
  77. args.wheel_override,
  78. )
  79. def get_docker_image(docker_override):
  80. """
  81. Get the docker image to use for the head node and worker nodes.
  82. Args:
  83. docker_override: The value of the --docker-override flag.
  84. Returns:
  85. The docker image to use for the head node and worker nodes, or None if not
  86. applicable.
  87. """
  88. if docker_override == "latest":
  89. return "rayproject/ray:latest-py310"
  90. elif docker_override == "nightly":
  91. return "rayproject/ray:nightly-py310"
  92. elif docker_override == "commit":
  93. if re.match("^[0-9]+.[0-9]+.[0-9]+$", ray.__version__):
  94. return f"rayproject/ray:{ray.__version__}.{ray.__commit__[:6]}-py310"
  95. else:
  96. print(
  97. "Error: docker image is only available for "
  98. f"release version, but we get: {ray.__version__}"
  99. )
  100. sys.exit(1)
  101. return None
  102. def check_file(file_path):
  103. """
  104. Check if the provided file path is valid and readable.
  105. Args:
  106. file_path: The path of the file to check.
  107. Raises:
  108. SystemExit: If the file is not readable or does not exist.
  109. """
  110. if not file_path.is_file() or not os.access(file_path, os.R_OK):
  111. print(f"Error: Cannot read cluster configuration file: {file_path}")
  112. sys.exit(1)
  113. def override_wheels_url(config_yaml, wheel_url):
  114. setup_commands = config_yaml.get("setup_commands", [])
  115. setup_commands.append(
  116. f'pip3 uninstall -y ray && pip3 install -U "ray[default] @ {wheel_url}"'
  117. )
  118. config_yaml["setup_commands"] = setup_commands
  119. def override_docker_image(config_yaml, docker_image):
  120. docker_config = config_yaml.get("docker", {})
  121. docker_config["image"] = docker_image
  122. docker_config["container_name"] = "ray_container"
  123. assert docker_config.get("head_image") is None, "Cannot override head_image"
  124. assert docker_config.get("worker_image") is None, "Cannot override worker_image"
  125. config_yaml["docker"] = docker_config
  126. def azure_authenticate():
  127. """Get Azure service principal credentials from AWS Secrets Manager and authenticate."""
  128. print("======================================")
  129. print("Getting Azure credentials from AWS Secrets Manager...")
  130. # Initialize AWS Secrets Manager client
  131. secrets_client = boto3.client("secretsmanager", region_name="us-west-2")
  132. # Get service principal credentials
  133. secret_response = secrets_client.get_secret_value(
  134. SecretId="azure-service-principal-oss-release"
  135. )
  136. secret = secret_response["SecretString"]
  137. client_id = json.loads(secret)["client_id"]
  138. tenant_id = json.loads(secret)["tenant_id"]
  139. # Get certificate
  140. cert_response = secrets_client.get_secret_value(
  141. SecretId="azure-service-principal-certificate"
  142. )
  143. cert = cert_response["SecretString"]
  144. # Write certificate to temp file
  145. tmp_dir = tempfile.mkdtemp()
  146. cert_path = os.path.join(tmp_dir, "azure_cert.pem")
  147. with open(cert_path, "w") as f:
  148. f.write(cert)
  149. # Login to Azure
  150. subprocess.check_call(
  151. [
  152. "az",
  153. "login",
  154. "--service-principal",
  155. "--username",
  156. client_id,
  157. "--certificate",
  158. cert_path,
  159. "--tenant",
  160. tenant_id,
  161. ]
  162. )
  163. def download_ssh_key_aws():
  164. """Download the ssh key from the S3 bucket to the local machine."""
  165. print("======================================")
  166. print("Downloading ssh key...")
  167. # Create a Boto3 client to interact with S3
  168. s3_client = boto3.client("s3", region_name="us-west-2")
  169. # Set the name of the S3 bucket and the key to download
  170. bucket_name = "aws-cluster-launcher-test"
  171. key_name = "ray-autoscaler_59_us-west-2.pem"
  172. # Download the key from the S3 bucket to a local file
  173. local_key_path = os.path.expanduser(f"~/.ssh/{key_name}")
  174. if not os.path.exists(os.path.dirname(local_key_path)):
  175. os.makedirs(os.path.dirname(local_key_path))
  176. s3_client.download_file(bucket_name, key_name, local_key_path)
  177. # Set permissions on the key file
  178. os.chmod(local_key_path, 0o400)
  179. def download_ssh_key_gcp():
  180. """Download the ssh key from the google cloud bucket to the local machine."""
  181. print("======================================")
  182. print("Downloading ssh key from GCP...")
  183. # Initialize the GCP storage client
  184. client = storage.Client()
  185. # Set the name of the GCS bucket and the blob (key) to download
  186. bucket_name = "gcp-cluster-launcher-release-test-ssh-keys"
  187. key_name = "ray-autoscaler_gcp_us-west1_anyscale-bridge-cd812d38_ubuntu_0.pem"
  188. # Get the bucket and blob
  189. bucket = client.get_bucket(bucket_name)
  190. blob = bucket.get_blob(key_name)
  191. # Download the blob to a local file
  192. local_key_path = os.path.expanduser(f"~/.ssh/{key_name}")
  193. if not os.path.exists(os.path.dirname(local_key_path)):
  194. os.makedirs(os.path.dirname(local_key_path))
  195. blob.download_to_filename(local_key_path)
  196. # Set permissions on the key file
  197. os.chmod(local_key_path, 0o400)
  198. def cleanup_cluster(config_yaml, cluster_config):
  199. """
  200. Clean up the cluster using the given cluster configuration file.
  201. Args:
  202. cluster_config: The path of the cluster configuration file.
  203. """
  204. print("======================================")
  205. print("Cleaning up cluster...")
  206. # We do multiple retries here because sometimes the cluster
  207. # fails to clean up properly, resulting in a non-zero exit code (e.g.
  208. # when processes have to be killed forcefully).
  209. last_error = None
  210. num_tries = 3
  211. for i in range(num_tries):
  212. try:
  213. env = os.environ.copy()
  214. env.pop("PYTHONPATH", None)
  215. subprocess.run(
  216. ["ray", "down", "-v", "-y", str(cluster_config)],
  217. check=True,
  218. capture_output=True,
  219. env=env,
  220. )
  221. cleanup_security_groups(config_yaml)
  222. # Final success
  223. return
  224. except subprocess.CalledProcessError as e:
  225. print(f"ray down fails[{i+1}/{num_tries}]: ")
  226. print(e.output.decode("utf-8"))
  227. # Print full traceback
  228. traceback.print_exc()
  229. # Print stdout and stderr from ray down
  230. print(f"stdout:\n{e.stdout.decode('utf-8')}")
  231. print(f"stderr:\n{e.stderr.decode('utf-8')}")
  232. last_error = e
  233. raise last_error
  234. def cleanup_security_group(ec2_client, id):
  235. retry = 0
  236. while retry < 10:
  237. try:
  238. ec2_client.delete_security_group(GroupId=id)
  239. return
  240. except botocore.exceptions.ClientError as e:
  241. if e.response["Error"]["Code"] == "DependencyViolation":
  242. sleep_time = 2**retry
  243. print(
  244. f"Waiting {sleep_time}s for the instance to be terminated before deleting the security group {id}" # noqa E501
  245. )
  246. time.sleep(sleep_time)
  247. retry += 1
  248. else:
  249. print(f"Error deleting security group: {e}")
  250. return
  251. def ensure_ssh_keys_azure():
  252. """
  253. Ensure that the SSH keys for Azure tests exist, and create them if they don't.
  254. """
  255. print("======================================")
  256. print("Ensuring Azure SSH keys exist...")
  257. private_key_path = os.path.expanduser("~/.ssh/ray-autoscaler-tests-ssh-key")
  258. public_key_path = os.path.expanduser("~/.ssh/ray-autoscaler-tests-ssh-key.pub")
  259. if os.path.exists(private_key_path) and os.path.exists(public_key_path):
  260. print("Azure SSH keys already exist.")
  261. return
  262. print("Azure SSH keys not found. Creating new keys...")
  263. ssh_dir = os.path.dirname(private_key_path)
  264. if not os.path.exists(ssh_dir):
  265. os.makedirs(ssh_dir, exist_ok=True)
  266. try:
  267. subprocess.run(
  268. [
  269. "ssh-keygen",
  270. "-t",
  271. "rsa",
  272. "-b",
  273. "4096",
  274. "-f",
  275. private_key_path,
  276. "-N",
  277. "",
  278. "-C",
  279. "ray-autoscaler-azure",
  280. ],
  281. check=True,
  282. capture_output=True,
  283. )
  284. print("Successfully created Azure SSH keys.")
  285. except subprocess.CalledProcessError as e:
  286. print("Error creating SSH keys:")
  287. print(f"stdout:\n{e.stdout.decode('utf-8')}")
  288. print(f"stderr:\n{e.stderr.decode('utf-8')}")
  289. sys.exit(1)
  290. def cleanup_security_groups(config):
  291. provider_type = config.get("provider", {}).get("type")
  292. if provider_type != "aws":
  293. return
  294. try:
  295. ec2_client = boto3.client("ec2", region_name="us-west-2")
  296. response = ec2_client.describe_security_groups(
  297. Filters=[
  298. {
  299. "Name": "tag-key",
  300. "Values": [RAY],
  301. },
  302. {
  303. "Name": "tag:ray-cluster-name",
  304. "Values": [config["cluster_name"]],
  305. },
  306. ]
  307. )
  308. for security_group in response["SecurityGroups"]:
  309. cleanup_security_group(ec2_client, security_group["GroupId"])
  310. except Exception as e:
  311. print(f"Error cleaning up security groups: {e}")
  312. def run_ray_commands(
  313. config_yaml,
  314. cluster_config,
  315. retries,
  316. no_config_cache,
  317. num_expected_nodes=1,
  318. ):
  319. """
  320. Run the necessary Ray commands to start a cluster, verify Ray is running, and clean
  321. up the cluster.
  322. Args:
  323. cluster_config: The path of the cluster configuration file.
  324. retries: The number of retries for the verification step.
  325. no_config_cache: Whether to pass the --no-config-cache flag to the ray CLI
  326. commands.
  327. """
  328. provider_type = config_yaml.get("provider", {}).get("type")
  329. if provider_type == "azure":
  330. azure_authenticate()
  331. print("======================================")
  332. print("Starting new cluster...")
  333. cmd = ["ray", "up", "-v", "-y"]
  334. if no_config_cache:
  335. cmd.append("--no-config-cache")
  336. cmd.append(str(cluster_config))
  337. env = os.environ.copy()
  338. env.pop("PYTHONPATH", None)
  339. try:
  340. subprocess.run(cmd, check=True, capture_output=True, env=env)
  341. except subprocess.CalledProcessError as e:
  342. print(e.output)
  343. # print stdout and stderr
  344. print(f"stdout:\n{e.stdout.decode('utf-8')}")
  345. print(f"stderr:\n{e.stderr.decode('utf-8')}")
  346. raise e
  347. print("======================================")
  348. print("Verifying Ray is running...")
  349. success = False
  350. count = 0
  351. while count < retries:
  352. try:
  353. cmd = [
  354. "ray",
  355. "exec",
  356. "-v",
  357. str(cluster_config),
  358. (
  359. 'python -c \'import ray; ray.init("localhost:6379");'
  360. + f" assert len(ray.nodes()) >= {num_expected_nodes}'"
  361. ),
  362. ]
  363. if no_config_cache:
  364. cmd.append("--no-config-cache")
  365. subprocess.run(cmd, check=True, env=env)
  366. success = True
  367. break
  368. except subprocess.CalledProcessError:
  369. count += 1
  370. print(f"Verification failed. Retry attempt {count} of {retries}...")
  371. time.sleep(60)
  372. if not success:
  373. print("======================================")
  374. print(
  375. f"Error: Verification failed after {retries} attempts. Cleaning up cluster "
  376. "before exiting..."
  377. )
  378. cleanup_cluster(config_yaml, cluster_config)
  379. print("======================================")
  380. print("Exiting script.")
  381. sys.exit(1)
  382. print("======================================")
  383. print("Ray verification successful.")
  384. cleanup_cluster(config_yaml, cluster_config)
  385. print("======================================")
  386. print("Finished executing script successfully.")
  387. if __name__ == "__main__":
  388. (
  389. cluster_config,
  390. retries,
  391. no_config_cache,
  392. num_expected_nodes,
  393. docker_override,
  394. wheel_override,
  395. ) = check_arguments()
  396. cluster_config = Path(cluster_config)
  397. check_file(cluster_config)
  398. print(f"Using cluster configuration file: {cluster_config}")
  399. print(f"Number of retries for 'verify ray is running' step: {retries}")
  400. print(f"Using --no-config-cache flag: {no_config_cache}")
  401. print(f"Number of expected nodes for 'verify ray is running': {num_expected_nodes}")
  402. config_yaml = yaml.safe_load(cluster_config.read_text())
  403. # Make the cluster name unique
  404. config_yaml["cluster_name"] = (
  405. config_yaml["cluster_name"] + "-" + str(int(time.time()))
  406. )
  407. print("======================================")
  408. print(f"Overriding ray wheel...: {wheel_override}")
  409. if wheel_override:
  410. override_wheels_url(config_yaml, wheel_override)
  411. print("======================================")
  412. print(f"Overriding docker image...: {docker_override}")
  413. docker_override_image = get_docker_image(docker_override)
  414. print(f"Using docker image: {docker_override_image}")
  415. if docker_override_image:
  416. override_docker_image(config_yaml, docker_override_image)
  417. provider_type = config_yaml.get("provider", {}).get("type")
  418. config_yaml["provider"]["cache_stopped_nodes"] = False
  419. if provider_type == "azure":
  420. ensure_ssh_keys_azure()
  421. elif provider_type == "aws":
  422. download_ssh_key_aws()
  423. config_yaml["provider"].pop("availability_zone", None)
  424. elif provider_type == "gcp":
  425. download_ssh_key_gcp()
  426. # Get the active account email
  427. account_email = (
  428. subprocess.run(
  429. ["gcloud", "config", "get-value", "account"],
  430. stdout=subprocess.PIPE,
  431. check=True,
  432. )
  433. .stdout.decode("utf-8")
  434. .strip()
  435. )
  436. print("Active account email:", account_email)
  437. # Get the current project ID
  438. project_id = (
  439. subprocess.run(
  440. ["gcloud", "config", "get-value", "project"],
  441. stdout=subprocess.PIPE,
  442. check=True,
  443. )
  444. .stdout.decode("utf-8")
  445. .strip()
  446. )
  447. print(
  448. f"Injecting GCP project '{project_id}' into cluster configuration file..."
  449. )
  450. config_yaml["provider"]["project_id"] = project_id
  451. elif provider_type == "vsphere":
  452. print("======================================")
  453. print("VSPHERE provider detected.")
  454. else:
  455. print("======================================")
  456. print("Provider type not recognized. Exiting script.")
  457. sys.exit(1)
  458. # Create a new temporary file and dump the updated configuration into it
  459. with tempfile.NamedTemporaryFile(suffix=".yaml") as temp:
  460. temp.write(yaml.dump(config_yaml).encode("utf-8"))
  461. temp.flush()
  462. cluster_config = Path(temp.name)
  463. run_ray_commands(
  464. config_yaml, cluster_config, retries, no_config_cache, num_expected_nodes
  465. )