cluster_utils.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. import copy
  2. import json
  3. import logging
  4. import os
  5. import subprocess
  6. import tempfile
  7. import time
  8. from typing import Dict, Optional
  9. import yaml
  10. import ray
  11. import ray._private.services
  12. from ray._private import ray_constants
  13. from ray._private.client_mode_hook import disable_client_hook
  14. from ray._raylet import GcsClientOptions
  15. from ray.autoscaler._private.fake_multi_node.node_provider import FAKE_HEAD_NODE_ID
  16. from ray.util.annotations import DeveloperAPI
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# True when running on Windows (os.name == "nt"). Cluster.__init__ only logs
# a warning in that case; it does not refuse to run.
cluster_not_supported = os.name == "nt"
  19. @DeveloperAPI
  20. class AutoscalingCluster:
  21. """Create a local autoscaling cluster for testing.
  22. See test_autoscaler_fake_multinode.py for an end-to-end example.
  23. """
  24. def __init__(
  25. self,
  26. head_resources: dict,
  27. worker_node_types: dict,
  28. autoscaler_v2: bool = False,
  29. **config_kwargs,
  30. ):
  31. """Create the cluster.
  32. Args:
  33. head_resources: resources of the head node, including CPU.
  34. worker_node_types: autoscaler node types config for worker nodes.
  35. """
  36. self._head_resources = head_resources
  37. self._config = self._generate_config(
  38. head_resources,
  39. worker_node_types,
  40. autoscaler_v2=autoscaler_v2,
  41. **config_kwargs,
  42. )
  43. self._autoscaler_v2 = autoscaler_v2
  44. def _generate_config(
  45. self, head_resources, worker_node_types, autoscaler_v2=False, **config_kwargs
  46. ):
  47. base_config = yaml.safe_load(
  48. open(
  49. os.path.join(
  50. os.path.dirname(ray.__file__),
  51. "autoscaler/_private/fake_multi_node/example.yaml",
  52. )
  53. )
  54. )
  55. custom_config = copy.deepcopy(base_config)
  56. custom_config["available_node_types"] = worker_node_types
  57. custom_config["available_node_types"]["ray.head.default"] = {
  58. "resources": head_resources,
  59. "node_config": {},
  60. "max_workers": 0,
  61. }
  62. # Autoscaler v2 specific configs
  63. if autoscaler_v2:
  64. custom_config["provider"]["launch_multiple"] = True
  65. custom_config["provider"]["head_node_id"] = FAKE_HEAD_NODE_ID
  66. custom_config.update(config_kwargs)
  67. return custom_config
  68. def start(self, _system_config=None, override_env: Optional[Dict] = None):
  69. """Start the cluster.
  70. After this call returns, you can connect to the cluster with
  71. ray.init("auto").
  72. """
  73. subprocess.check_call(["ray", "stop", "--force"])
  74. _, fake_config = tempfile.mkstemp()
  75. with open(fake_config, "w") as f:
  76. f.write(json.dumps(self._config))
  77. cmd = [
  78. "ray",
  79. "start",
  80. "--autoscaling-config={}".format(fake_config),
  81. "--head",
  82. ]
  83. if "CPU" in self._head_resources:
  84. cmd.append("--num-cpus={}".format(self._head_resources.pop("CPU")))
  85. if "GPU" in self._head_resources:
  86. cmd.append("--num-gpus={}".format(self._head_resources.pop("GPU")))
  87. if "object_store_memory" in self._head_resources:
  88. cmd.append(
  89. "--object-store-memory={}".format(
  90. self._head_resources.pop("object_store_memory")
  91. )
  92. )
  93. if self._head_resources:
  94. cmd.append("--resources='{}'".format(json.dumps(self._head_resources)))
  95. if _system_config is not None:
  96. cmd.append(
  97. "--system-config={}".format(
  98. json.dumps(_system_config, separators=(",", ":"))
  99. )
  100. )
  101. env = os.environ.copy()
  102. env.update({"AUTOSCALER_UPDATE_INTERVAL_S": "1", "RAY_FAKE_CLUSTER": "1"})
  103. if self._autoscaler_v2:
  104. # Set the necessary environment variables for autoscaler v2.
  105. env.update(
  106. {
  107. "RAY_enable_autoscaler_v2": "1",
  108. "RAY_CLOUD_INSTANCE_ID": FAKE_HEAD_NODE_ID,
  109. "RAY_OVERRIDE_NODE_ID_FOR_TESTING": FAKE_HEAD_NODE_ID,
  110. }
  111. )
  112. if override_env:
  113. env.update(override_env)
  114. subprocess.check_call(cmd, env=env)
  115. def shutdown(self):
  116. """Terminate the cluster."""
  117. subprocess.check_call(["ray", "stop", "--force"])
  118. @DeveloperAPI
  119. class Cluster:
  120. def __init__(
  121. self,
  122. initialize_head: bool = False,
  123. connect: bool = False,
  124. head_node_args: dict = None,
  125. shutdown_at_exit: bool = True,
  126. ):
  127. """Initializes all services of a Ray cluster.
  128. Args:
  129. initialize_head: Automatically start a Ray cluster
  130. by initializing the head node. Defaults to False.
  131. connect: If `initialize_head=True` and `connect=True`,
  132. ray.init will be called with the address of this cluster
  133. passed in.
  134. head_node_args: Arguments to be passed into
  135. `start_ray_head` via `self.add_node`.
  136. shutdown_at_exit: If True, registers an exit hook
  137. for shutting down all started processes.
  138. """
  139. if cluster_not_supported:
  140. logger.warning(
  141. "Ray cluster mode is currently experimental and untested on "
  142. "Windows. If you are using it and running into issues please "
  143. "file a report at https://github.com/ray-project/ray/issues."
  144. )
  145. self.head_node = None
  146. self.worker_nodes = set()
  147. self.redis_address = None
  148. self.connected = False
  149. # Create a new global state accessor for fetching GCS table.
  150. self.global_state = ray._private.state.GlobalState()
  151. self._shutdown_at_exit = shutdown_at_exit
  152. if not initialize_head and connect:
  153. raise RuntimeError("Cannot connect to uninitialized cluster.")
  154. if initialize_head:
  155. head_node_args = head_node_args or {}
  156. self.add_node(**head_node_args)
  157. if connect:
  158. self.connect()
  159. @property
  160. def gcs_address(self):
  161. if self.head_node is None:
  162. return None
  163. return self.head_node.gcs_address
  164. @property
  165. def address(self):
  166. return self.gcs_address
  167. def connect(self, namespace=None):
  168. """Connect the driver to the cluster."""
  169. assert self.address is not None
  170. assert not self.connected
  171. output_info = ray.init(
  172. namespace=namespace,
  173. ignore_reinit_error=True,
  174. address=self.address,
  175. _redis_username=self.redis_username,
  176. _redis_password=self.redis_password,
  177. )
  178. logger.info(output_info)
  179. self.connected = True
  180. def add_node(self, wait: bool = True, **node_args):
  181. """Adds a node to the local Ray Cluster.
  182. All nodes are by default started with the following settings:
  183. cleanup=True,
  184. num_cpus=1,
  185. object_store_memory=150 * 1024 * 1024 # 150 MiB
  186. Args:
  187. wait: Whether to wait until the node is alive.
  188. node_args: Keyword arguments used in `start_ray_head` and
  189. `start_ray_node`. Overrides defaults.
  190. Returns:
  191. Node object of the added Ray node.
  192. """
  193. default_kwargs = {
  194. "num_cpus": 1,
  195. "num_gpus": 0,
  196. "object_store_memory": 150 * 1024 * 1024, # 150 MiB
  197. "min_worker_port": 0,
  198. "max_worker_port": 0,
  199. }
  200. ray_params = ray._private.parameter.RayParams(**node_args)
  201. ray_params.update_if_absent(**default_kwargs)
  202. with disable_client_hook():
  203. if self.head_node is None:
  204. node = ray._private.node.Node(
  205. ray_params,
  206. head=True,
  207. shutdown_at_exit=self._shutdown_at_exit,
  208. spawn_reaper=self._shutdown_at_exit,
  209. )
  210. self.head_node = node
  211. self.redis_address = self.head_node.redis_address
  212. self.redis_username = node_args.get(
  213. "redis_username", ray_constants.REDIS_DEFAULT_USERNAME
  214. )
  215. self.redis_password = node_args.get(
  216. "redis_password", ray_constants.REDIS_DEFAULT_PASSWORD
  217. )
  218. self.webui_url = self.head_node.webui_url
  219. # Init global state accessor when creating head node.
  220. gcs_options = GcsClientOptions.create(
  221. node.gcs_address,
  222. None,
  223. allow_cluster_id_nil=True,
  224. fetch_cluster_id_if_nil=False,
  225. )
  226. self.global_state._initialize_global_state(gcs_options)
  227. # Write the Ray cluster address for convenience in unit
  228. # testing. ray.init() and ray.init(address="auto") will connect
  229. # to the local cluster.
  230. ray._private.utils.write_ray_address(self.head_node.gcs_address)
  231. else:
  232. ray_params.update_if_absent(redis_address=self.redis_address)
  233. ray_params.update_if_absent(gcs_address=self.gcs_address)
  234. # We only need one log monitor per physical node.
  235. ray_params.update_if_absent(include_log_monitor=False)
  236. # Let grpc pick a port.
  237. ray_params.update_if_absent(node_manager_port=0)
  238. if "dashboard_agent_listen_port" not in node_args:
  239. # Pick a random one to not conflict
  240. # with the head node dashboard agent
  241. ray_params.dashboard_agent_listen_port = None
  242. node = ray._private.node.Node(
  243. ray_params,
  244. head=False,
  245. shutdown_at_exit=self._shutdown_at_exit,
  246. spawn_reaper=self._shutdown_at_exit,
  247. )
  248. self.worker_nodes.add(node)
  249. if wait:
  250. # Wait for the node to appear in the client table. We do this
  251. # so that the nodes appears in the client table in the order
  252. # that the corresponding calls to add_node were made. We do
  253. # this because in the tests we assume that the driver is
  254. # connected to the first node that is added.
  255. self._wait_for_node(node)
  256. return node
  257. def remove_node(self, node, allow_graceful=True):
  258. """Kills all processes associated with worker node.
  259. Args:
  260. node: Worker node of which all associated processes
  261. will be removed.
  262. """
  263. global_node = ray._private.worker.global_worker.node
  264. if global_node is not None:
  265. if node._raylet_socket_name == global_node._raylet_socket_name:
  266. ray.shutdown()
  267. raise ValueError(
  268. "Removing a node that is connected to this Ray client "
  269. "is not allowed because it will break the driver. "
  270. "You can use the get_other_node utility to avoid removing "
  271. "a node that the Ray client is connected."
  272. )
  273. node.destroy_external_storage()
  274. if self.head_node == node:
  275. # We have to wait to prevent the raylet becomes a zombie which will prevent
  276. # worker from exiting
  277. self.head_node.kill_all_processes(
  278. check_alive=False, allow_graceful=allow_graceful, wait=True
  279. )
  280. self.head_node = None
  281. # TODO(rliaw): Do we need to kill all worker processes?
  282. else:
  283. # We have to wait to prevent the raylet becomes a zombie which will prevent
  284. # worker from exiting
  285. node.kill_all_processes(
  286. check_alive=False, allow_graceful=allow_graceful, wait=True
  287. )
  288. self.worker_nodes.remove(node)
  289. assert (
  290. not node.any_processes_alive()
  291. ), "There are zombie processes left over after killing."
  292. def _wait_for_node(self, node, timeout: float = 30):
  293. """Wait until this node has appeared in the client table.
  294. Args:
  295. node (ray._private.node.Node): The node to wait for.
  296. timeout: The amount of time in seconds to wait before raising an
  297. exception.
  298. Raises:
  299. TimeoutError: An exception is raised if the timeout expires before
  300. the node appears in the client table.
  301. """
  302. ray._private.services.wait_for_node(
  303. node.gcs_address,
  304. node.plasma_store_socket_name,
  305. timeout,
  306. )
  307. def wait_for_nodes(self, timeout: float = 30):
  308. """Waits for correct number of nodes to be registered.
  309. This will wait until the number of live nodes in the client table
  310. exactly matches the number of "add_node" calls minus the number of
  311. "remove_node" calls that have been made on this cluster. This means
  312. that if a node dies without "remove_node" having been called, this will
  313. raise an exception.
  314. Args:
  315. timeout: The number of seconds to wait for nodes to join
  316. before failing.
  317. Raises:
  318. TimeoutError: An exception is raised if we time out while waiting
  319. for nodes to join.
  320. """
  321. start_time = time.time()
  322. while time.time() - start_time < timeout:
  323. live_clients = self.global_state._live_node_ids()
  324. expected = len(self.list_all_nodes())
  325. if len(live_clients) == expected:
  326. logger.debug("All nodes registered as expected.")
  327. return
  328. else:
  329. logger.debug(
  330. f"{len(live_clients)} nodes are currently registered, "
  331. f"but we are expecting {expected}"
  332. )
  333. time.sleep(0.1)
  334. raise TimeoutError("Timed out while waiting for nodes to join.")
  335. def list_all_nodes(self):
  336. """Lists all nodes.
  337. TODO(rliaw): What is the desired behavior if a head node
  338. dies before worker nodes die?
  339. Returns:
  340. List of all nodes, including the head node.
  341. """
  342. nodes = list(self.worker_nodes)
  343. if self.head_node:
  344. nodes = [self.head_node] + nodes
  345. return nodes
  346. def remaining_processes_alive(self):
  347. """Returns a bool indicating whether all processes are alive or not.
  348. Note that this ignores processes that have been explicitly killed,
  349. e.g., via a command like node.kill_raylet().
  350. Returns:
  351. True if all processes are alive and false otherwise.
  352. """
  353. return all(node.remaining_processes_alive() for node in self.list_all_nodes())
  354. def shutdown(self):
  355. """Removes all nodes."""
  356. # We create a list here as a copy because `remove_node`
  357. # modifies `self.worker_nodes`.
  358. all_nodes = list(self.worker_nodes)
  359. for node in all_nodes:
  360. self.remove_node(node)
  361. if self.head_node is not None:
  362. self.remove_node(self.head_node)
  363. # need to reset internal kv since gcs is down
  364. ray.experimental.internal_kv._internal_kv_reset()
  365. # Delete the cluster address.
  366. ray._common.utils.reset_ray_address()