config.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. import copy
  2. import logging
  3. import os
  4. from cryptography.hazmat.primitives import serialization as crypto_serialization
  5. from cryptography.hazmat.primitives.asymmetric import rsa
  6. from ray.autoscaler._private.constants import DISABLE_NODE_UPDATERS_KEY
  7. from ray.autoscaler._private.event_system import CreateClusterEvent, global_event_system
  8. from ray.autoscaler._private.util import check_legacy_fields
  9. PRIVATE_KEY_NAME = "ray-bootstrap-key.pem"
  10. PUBLIC_KEY_NAME = "ray_bootstrap_public_key.key"
  11. PRIVATE_KEY_PATH = os.path.expanduser(f"~/{PRIVATE_KEY_NAME}")
  12. PUBLIC_KEY_PATH = os.path.expanduser(f"~/{PUBLIC_KEY_NAME}")
  13. logger = logging.getLogger(__name__)
  14. def bootstrap_vsphere(config):
  15. # create a copy of the input config to modify
  16. config = copy.deepcopy(config)
  17. # Log warnings if user included deprecated `head_node` or `worker_nodes`
  18. # fields. Raise error if no `available_node_types`
  19. check_legacy_fields(config)
  20. # Configure SSH access, using an existing key pair if possible.
  21. config = configure_key_pair(config)
  22. # Configure docker run command to be executed on head and wroker nodes
  23. config = configure_run_options(config)
  24. global_event_system.execute_callback(
  25. CreateClusterEvent.ssh_keypair_downloaded,
  26. {"ssh_key_path": config["auth"]["ssh_private_key"]},
  27. )
  28. logger.info(f"{config}")
  29. return config
  30. def configure_key_pair(config):
  31. logger.info("Configuring keys for Ray Cluster Launcher to ssh into the head node.")
  32. if not os.path.exists(PRIVATE_KEY_PATH):
  33. logger.warning(
  34. "Private key file at path {} was not found".format(PRIVATE_KEY_PATH)
  35. )
  36. _create_ssh_keys()
  37. logger.info(
  38. f"New SSH key pair {PRIVATE_KEY_PATH} and {PUBLIC_KEY_PATH} created."
  39. )
  40. # updater.py file uses the following config to ssh onto the head node
  41. # Also, copies the file onto the head node
  42. config["auth"]["ssh_private_key"] = PRIVATE_KEY_PATH
  43. # The path where the public key should be copied onto the remote host
  44. public_key_remote_path = f"~/{PUBLIC_KEY_NAME}"
  45. # Copy the public key to the remote host
  46. config["file_mounts"][public_key_remote_path] = PUBLIC_KEY_PATH
  47. return config
  48. def configure_run_options(config):
  49. ssh_user = config["auth"]["ssh_user"]
  50. # By default enable TLS for Head-Worker grpc communication
  51. tls_enable = (
  52. 1 if config["provider"]["vsphere_config"].get("tls_enable", True) else 0
  53. )
  54. # Configure common run options
  55. if "run_options" not in config["docker"]:
  56. config["docker"]["run_options"] = []
  57. config["docker"]["run_options"].append(f"--env RAY_USE_TLS={tls_enable}")
  58. # Configure head_run_options
  59. if "head_run_options" not in config["docker"]:
  60. config["docker"]["head_run_options"] = []
  61. config["docker"]["head_run_options"].append(
  62. f"--env-file /home/{ssh_user}/svc-account-token.env"
  63. )
  64. # Configure worker_run_options
  65. if "worker_run_options" not in config["docker"]:
  66. config["docker"]["worker_run_options"] = []
  67. if tls_enable == 1:
  68. # Generate TLS cert and key for head and worker nodes.
  69. # This needs to be done before ray start command
  70. config["head_start_ray_commands"].insert(0, "sh /home/ray/gencert.sh")
  71. config["worker_start_ray_commands"].insert(0, "sh /home/ray/gencert.sh")
  72. config["docker"]["run_options"].append(
  73. f"-v /home/{ssh_user}/ca.crt:/home/ray/ca.crt"
  74. )
  75. config["docker"]["run_options"].append(
  76. f"-v /home/{ssh_user}/ca.key:/home/ray/ca.key"
  77. )
  78. config["docker"]["run_options"].append(
  79. f"-v /home/{ssh_user}/gencert.sh:/home/ray/gencert.sh"
  80. )
  81. config["docker"]["run_options"].append("--env RAY_TLS_CA_CERT=/home/ray/ca.crt")
  82. config["docker"]["run_options"].append(
  83. "--env RAY_TLS_SERVER_KEY=/home/ray/tls.key"
  84. )
  85. config["docker"]["run_options"].append(
  86. "--env RAY_TLS_SERVER_CERT=/home/ray/tls.crt"
  87. )
  88. return config
  89. def disable_node_updater(config):
  90. logger.info(
  91. "Disabling NodeUpdater threads as Cluster Operator is "
  92. + "responsible for Ray setup on nodes."
  93. )
  94. config["provider"][DISABLE_NODE_UPDATERS_KEY] = True
  95. return config
  96. def _create_ssh_keys():
  97. """Create SSH keys as specified"""
  98. # Create a private key
  99. private_key = rsa.generate_private_key(public_exponent=65537, key_size=4096)
  100. # Encode it in PEM format
  101. unencrypted_private_key = private_key.private_bytes(
  102. encoding=crypto_serialization.Encoding.PEM,
  103. format=crypto_serialization.PrivateFormat.TraditionalOpenSSL,
  104. encryption_algorithm=crypto_serialization.NoEncryption(),
  105. )
  106. # Create a public key
  107. public_key = private_key.public_key().public_bytes(
  108. encoding=crypto_serialization.Encoding.PEM,
  109. format=crypto_serialization.PublicFormat.SubjectPublicKeyInfo,
  110. )
  111. # Write keys
  112. with open(PRIVATE_KEY_PATH, "wb") as pvt_key:
  113. # manage access mode for the pvt key
  114. os.chmod(PRIVATE_KEY_PATH, 0o600)
  115. pvt_key.write(unencrypted_private_key)
  116. with open(PUBLIC_KEY_PATH, "wb") as pub_key:
  117. # manage access mode for the pvt key
  118. pub_key.write(public_key)