| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415 |
- import copy
- import json
- import logging
- import os
- import subprocess
- import tempfile
- import time
- from typing import Dict, Optional
- import yaml
- import ray
- import ray._private.services
- from ray._private import ray_constants
- from ray._private.client_mode_hook import disable_client_hook
- from ray._raylet import GcsClientOptions
- from ray.autoscaler._private.fake_multi_node.node_provider import FAKE_HEAD_NODE_ID
- from ray.util.annotations import DeveloperAPI
- logger = logging.getLogger(__name__)
- cluster_not_supported = os.name == "nt"
@DeveloperAPI
class AutoscalingCluster:
    """Create a local autoscaling cluster for testing.

    See test_autoscaler_fake_multinode.py for an end-to-end example.
    """

    def __init__(
        self,
        head_resources: dict,
        worker_node_types: dict,
        autoscaler_v2: bool = False,
        **config_kwargs,
    ):
        """Create the cluster.

        Args:
            head_resources: resources of the head node, including CPU.
            worker_node_types: autoscaler node types config for worker nodes.
            autoscaler_v2: if True, add the autoscaler v2 provider settings
                to the generated config and export the autoscaler v2
                environment variables when the cluster is started.
            **config_kwargs: extra top-level keys merged into the generated
                autoscaler config; they override generated keys of the same
                name.
        """
        self._head_resources = head_resources
        self._config = self._generate_config(
            head_resources,
            worker_node_types,
            autoscaler_v2=autoscaler_v2,
            **config_kwargs,
        )
        self._autoscaler_v2 = autoscaler_v2

    def _generate_config(
        self, head_resources, worker_node_types, autoscaler_v2=False, **config_kwargs
    ):
        """Build the autoscaler config dict for the fake multi-node provider.

        The example YAML shipped next to the ``ray`` package is loaded and
        its ``available_node_types`` replaced by ``worker_node_types`` plus a
        ``ray.head.default`` entry describing the head node.

        Args:
            head_resources: resources advertised for the head node type.
            worker_node_types: node type configs for the worker nodes. The
                mapping itself is not modified.
            autoscaler_v2: whether to add the autoscaler v2 provider keys.
            **config_kwargs: top-level overrides applied last.

        Returns:
            The generated config as a dict.
        """
        example_path = os.path.join(
            os.path.dirname(ray.__file__),
            "autoscaler/_private/fake_multi_node/example.yaml",
        )
        # Use a context manager so the file handle is closed deterministically.
        with open(example_path) as f:
            custom_config = yaml.safe_load(f)

        # Shallow-copy so that adding the head node type below does not leak
        # a "ray.head.default" entry into the caller's dict.
        node_types = dict(worker_node_types)
        node_types["ray.head.default"] = {
            "resources": head_resources,
            "node_config": {},
            "max_workers": 0,
        }
        custom_config["available_node_types"] = node_types

        # Autoscaler v2 specific configs
        if autoscaler_v2:
            custom_config["provider"]["launch_multiple"] = True
            custom_config["provider"]["head_node_id"] = FAKE_HEAD_NODE_ID

        custom_config.update(config_kwargs)
        return custom_config

    def start(self, _system_config=None, override_env: Optional[Dict] = None):
        """Start the cluster.

        After this call returns, you can connect to the cluster with
        ray.init("auto").

        Args:
            _system_config: optional system config; when given it is passed
                to ``ray start`` as compact JSON via ``--system-config``.
            override_env: optional environment variables applied on top of
                the environment used for ``ray start``; they take precedence
                over the variables set by this method.

        Raises:
            subprocess.CalledProcessError: if ``ray stop`` or ``ray start``
                exits with a non-zero status.
        """
        subprocess.check_call(["ray", "stop", "--force"])

        # mkstemp returns an already-open descriptor; wrap it so it is closed
        # after writing instead of being leaked.
        # NOTE(review): the file is deliberately not deleted here; presumably
        # the autoscaler keeps reading it after `ray start` returns - confirm
        # before adding cleanup.
        fd, fake_config = tempfile.mkstemp()
        with os.fdopen(fd, "w") as f:
            f.write(json.dumps(self._config))

        cmd = [
            "ray",
            "start",
            "--autoscaling-config={}".format(fake_config),
            "--head",
        ]

        # Work on a copy: popping from self._head_resources would mutate the
        # caller's dict (and the dict referenced from self._config), making a
        # second start() silently drop CPU/GPU/object_store_memory.
        remaining = dict(self._head_resources)
        dedicated_flags = (
            ("CPU", "--num-cpus"),
            ("GPU", "--num-gpus"),
            ("object_store_memory", "--object-store-memory"),
        )
        for resource_name, flag in dedicated_flags:
            if resource_name in remaining:
                cmd.append("{}={}".format(flag, remaining.pop(resource_name)))
        if remaining:
            # NOTE(review): the single quotes are passed literally because no
            # shell is involved; kept as-is, confirm `ray start` accepts them.
            cmd.append("--resources='{}'".format(json.dumps(remaining)))

        if _system_config is not None:
            cmd.append(
                "--system-config={}".format(
                    json.dumps(_system_config, separators=(",", ":"))
                )
            )

        env = os.environ.copy()
        env.update({"AUTOSCALER_UPDATE_INTERVAL_S": "1", "RAY_FAKE_CLUSTER": "1"})
        if self._autoscaler_v2:
            # Set the necessary environment variables for autoscaler v2.
            env.update(
                {
                    "RAY_enable_autoscaler_v2": "1",
                    "RAY_CLOUD_INSTANCE_ID": FAKE_HEAD_NODE_ID,
                    "RAY_OVERRIDE_NODE_ID_FOR_TESTING": FAKE_HEAD_NODE_ID,
                }
            )
        if override_env:
            env.update(override_env)
        subprocess.check_call(cmd, env=env)

    def shutdown(self):
        """Terminate the cluster."""
        subprocess.check_call(["ray", "stop", "--force"])
@DeveloperAPI
class Cluster:
    """A local multi-node Ray cluster made of in-process managed nodes.

    The first node added becomes the head node; every later node is added as
    a worker node attached to that head.
    """

    def __init__(
        self,
        initialize_head: bool = False,
        connect: bool = False,
        head_node_args: Optional[dict] = None,
        shutdown_at_exit: bool = True,
    ):
        """Initializes all services of a Ray cluster.

        Args:
            initialize_head: Automatically start a Ray cluster
                by initializing the head node. Defaults to False.
            connect: If `initialize_head=True` and `connect=True`,
                ray.init will be called with the address of this cluster
                passed in.
            head_node_args: Arguments to be passed into
                `start_ray_head` via `self.add_node`.
            shutdown_at_exit: If True, registers an exit hook
                for shutting down all started processes.

        Raises:
            RuntimeError: If `connect=True` is requested without
                `initialize_head=True`.
        """
        if cluster_not_supported:
            logger.warning(
                "Ray cluster mode is currently experimental and untested on "
                "Windows. If you are using it and running into issues please "
                "file a report at https://github.com/ray-project/ray/issues."
            )
        self.head_node = None
        self.worker_nodes = set()
        self.redis_address = None
        # Filled in by add_node() when the head node is created. Defined here
        # so that reading them on a head-less cluster yields None instead of
        # raising AttributeError.
        self.redis_username = None
        self.redis_password = None
        self.webui_url = None
        self.connected = False
        # Create a new global state accessor for fetching GCS table.
        self.global_state = ray._private.state.GlobalState()
        self._shutdown_at_exit = shutdown_at_exit

        if not initialize_head and connect:
            raise RuntimeError("Cannot connect to uninitialized cluster.")

        if initialize_head:
            head_node_args = head_node_args or {}
            self.add_node(**head_node_args)
            if connect:
                self.connect()

    @property
    def gcs_address(self):
        """GCS address of the head node, or None if there is no head node."""
        if self.head_node is None:
            return None
        return self.head_node.gcs_address

    @property
    def address(self):
        """Address of the cluster; an alias for `gcs_address`."""
        return self.gcs_address

    def connect(self, namespace=None):
        """Connect the driver to the cluster.

        Args:
            namespace: Namespace forwarded to `ray.init`.
        """
        assert self.address is not None
        assert not self.connected
        output_info = ray.init(
            namespace=namespace,
            ignore_reinit_error=True,
            address=self.address,
            _redis_username=self.redis_username,
            _redis_password=self.redis_password,
        )
        logger.info(output_info)
        self.connected = True

    def add_node(self, wait: bool = True, **node_args):
        """Adds a node to the local Ray Cluster.

        The first node added becomes the head node; later calls add worker
        nodes. All nodes are by default started with the following settings:
            num_cpus=1,
            num_gpus=0,
            object_store_memory=150 * 1024 * 1024  # 150 MiB
            min_worker_port=0,
            max_worker_port=0,

        Args:
            wait: Whether to wait until the node is alive.
            node_args: Keyword arguments used in `start_ray_head` and
                `start_ray_node`. Overrides defaults.

        Returns:
            Node object of the added Ray node.
        """
        default_kwargs = {
            "num_cpus": 1,
            "num_gpus": 0,
            "object_store_memory": 150 * 1024 * 1024,  # 150 MiB
            "min_worker_port": 0,
            "max_worker_port": 0,
        }
        ray_params = ray._private.parameter.RayParams(**node_args)
        ray_params.update_if_absent(**default_kwargs)
        with disable_client_hook():
            if self.head_node is None:
                node = ray._private.node.Node(
                    ray_params,
                    head=True,
                    shutdown_at_exit=self._shutdown_at_exit,
                    spawn_reaper=self._shutdown_at_exit,
                )
                self.head_node = node
                self.redis_address = self.head_node.redis_address
                self.redis_username = node_args.get(
                    "redis_username", ray_constants.REDIS_DEFAULT_USERNAME
                )
                self.redis_password = node_args.get(
                    "redis_password", ray_constants.REDIS_DEFAULT_PASSWORD
                )
                self.webui_url = self.head_node.webui_url
                # Init global state accessor when creating head node.
                gcs_options = GcsClientOptions.create(
                    node.gcs_address,
                    None,
                    allow_cluster_id_nil=True,
                    fetch_cluster_id_if_nil=False,
                )
                self.global_state._initialize_global_state(gcs_options)
                # Write the Ray cluster address for convenience in unit
                # testing. ray.init() and ray.init(address="auto") will connect
                # to the local cluster.
                ray._private.utils.write_ray_address(self.head_node.gcs_address)
            else:
                ray_params.update_if_absent(redis_address=self.redis_address)
                ray_params.update_if_absent(gcs_address=self.gcs_address)
                # We only need one log monitor per physical node.
                ray_params.update_if_absent(include_log_monitor=False)
                # Let grpc pick a port.
                ray_params.update_if_absent(node_manager_port=0)
                if "dashboard_agent_listen_port" not in node_args:
                    # Pick a random one to not conflict
                    # with the head node dashboard agent
                    ray_params.dashboard_agent_listen_port = None

                node = ray._private.node.Node(
                    ray_params,
                    head=False,
                    shutdown_at_exit=self._shutdown_at_exit,
                    spawn_reaper=self._shutdown_at_exit,
                )
                self.worker_nodes.add(node)

            if wait:
                # Wait for the node to appear in the client table. We do this
                # so that the nodes appears in the client table in the order
                # that the corresponding calls to add_node were made. We do
                # this because in the tests we assume that the driver is
                # connected to the first node that is added.
                self._wait_for_node(node)

        return node

    def remove_node(self, node, allow_graceful=True):
        """Kills all processes associated with worker node.

        Args:
            node: Worker node of which all associated processes
                will be removed.
            allow_graceful: Forwarded to `node.kill_all_processes`.

        Raises:
            ValueError: If `node` is the node the current driver is connected
                to. The driver is shut down via `ray.shutdown()` before the
                error is raised.
        """
        global_node = ray._private.worker.global_worker.node
        if global_node is not None:
            if node._raylet_socket_name == global_node._raylet_socket_name:
                ray.shutdown()
                raise ValueError(
                    "Removing a node that is connected to this Ray client "
                    "is not allowed because it will break the driver. "
                    "You can use the get_other_node utility to avoid removing "
                    "a node that the Ray client is connected."
                )

        node.destroy_external_storage()

        if self.head_node == node:
            # Wait for the processes to exit; otherwise the raylet can become
            # a zombie, which prevents workers from exiting.
            self.head_node.kill_all_processes(
                check_alive=False, allow_graceful=allow_graceful, wait=True
            )
            self.head_node = None
            # TODO(rliaw): Do we need to kill all worker processes?
        else:
            # Wait for the processes to exit; otherwise the raylet can become
            # a zombie, which prevents workers from exiting.
            node.kill_all_processes(
                check_alive=False, allow_graceful=allow_graceful, wait=True
            )
            self.worker_nodes.remove(node)

        assert (
            not node.any_processes_alive()
        ), "There are zombie processes left over after killing."

    def _wait_for_node(self, node, timeout: float = 30):
        """Wait until this node has appeared in the client table.

        Args:
            node (ray._private.node.Node): The node to wait for.
            timeout: The amount of time in seconds to wait before raising an
                exception.

        Raises:
            TimeoutError: An exception is raised if the timeout expires before
                the node appears in the client table.
        """
        ray._private.services.wait_for_node(
            node.gcs_address,
            node.plasma_store_socket_name,
            timeout,
        )

    def wait_for_nodes(self, timeout: float = 30):
        """Waits for correct number of nodes to be registered.

        This will wait until the number of live nodes in the client table
        exactly matches the number of "add_node" calls minus the number of
        "remove_node" calls that have been made on this cluster. This means
        that if a node dies without "remove_node" having been called, this will
        raise an exception.

        Args:
            timeout: The number of seconds to wait for nodes to join
                before failing.

        Raises:
            TimeoutError: An exception is raised if we time out while waiting
                for nodes to join.
        """
        # Use a monotonic clock so wall-clock adjustments cannot shorten or
        # extend the timeout.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            live_clients = self.global_state._live_node_ids()
            expected = len(self.list_all_nodes())
            if len(live_clients) == expected:
                logger.debug("All nodes registered as expected.")
                return
            logger.debug(
                f"{len(live_clients)} nodes are currently registered, "
                f"but we are expecting {expected}"
            )
            time.sleep(0.1)
        raise TimeoutError("Timed out while waiting for nodes to join.")

    def list_all_nodes(self):
        """Lists all nodes.

        TODO(rliaw): What is the desired behavior if a head node
        dies before worker nodes die?

        Returns:
            List of all nodes, including the head node.
        """
        nodes = list(self.worker_nodes)
        if self.head_node:
            nodes = [self.head_node] + nodes
        return nodes

    def remaining_processes_alive(self):
        """Returns a bool indicating whether all processes are alive or not.

        Note that this ignores processes that have been explicitly killed,
        e.g., via a command like node.kill_raylet().

        Returns:
            True if all processes are alive and false otherwise.
        """
        return all(node.remaining_processes_alive() for node in self.list_all_nodes())

    def shutdown(self):
        """Removes all nodes."""
        # We create a list here as a copy because `remove_node`
        # modifies `self.worker_nodes`.
        all_nodes = list(self.worker_nodes)
        for node in all_nodes:
            self.remove_node(node)

        if self.head_node is not None:
            self.remove_node(self.head_node)
        # need to reset internal kv since gcs is down
        ray.experimental.internal_kv._internal_kv_reset()
        # Delete the cluster address.
        ray._common.utils.reset_ray_address()
|