| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- import time
- from collections import Counter
- from typing import List, NamedTuple
- from ray._raylet import GcsClient
- from ray.autoscaler.v2.schema import ClusterStatus, Stats
- from ray.autoscaler.v2.utils import ClusterStatusParser
- from ray.core.generated.autoscaler_pb2 import (
- ClusterResourceState,
- GetClusterResourceStateReply,
- GetClusterStatusReply,
- )
# Default timeout, in seconds, applied to the GCS RPCs issued by
# request_cluster_resources() and get_cluster_status() when the caller does
# not pass an explicit `timeout`.
DEFAULT_RPC_TIMEOUT_S = 10
class ResourceRequest(NamedTuple):
    """A single requested bundle: its resources and the label selector for it.

    request_cluster_resources() converts each incoming request dict into one
    of these before grouping identical requests together.
    """

    # Resource name -> amount, e.g. {"CPU": 1, "GPU": 1}.
    resources: dict
    # Label key -> value to apply to this bundle,
    # e.g. {"accelerator-type": "A100"}.
    label_selector: dict
def request_cluster_resources(
    gcs_address: str,
    to_request: List[dict],
    timeout: int = DEFAULT_RPC_TIMEOUT_S,
):
    """Ask the autoscaler to provision a set of resource bundles.

    The request is stored in GCS as a cluster resource constraint. GCS hands
    the constraint to the autoscaler asynchronously, and the autoscaler then
    tries to provision the minimal bundles described by `to_request`.

    If the cluster already has the requested resources, this is a no-op.
    Every call through this API overwrites the request made by the previous
    call.

    Args:
        gcs_address: The GCS address to query. Must be non-empty.
        to_request: A list of resource requests the cluster should satisfy.
            Each entry is a dict with a "resources" mapping and a
            "label_selector" mapping to apply per-bundle; either key may be
            omitted and then defaults to an empty dict. e.g.:
            [{"resources": {"CPU": 1, "GPU": 1},
              "label_selector": {"accelerator-type": "A100"}}]
        timeout: Timeout in seconds for the GCS request.

    Raises:
        AssertionError: If `gcs_address` is empty or an entry of `to_request`
            is not a dict.
    """
    assert len(gcs_address) > 0, "GCS address is not specified."

    def _as_request(entry) -> ResourceRequest:
        # Turn one caller-supplied dict into a ResourceRequest tuple.
        assert isinstance(
            entry, dict
        ), f"Internal Error: Expected a dict, but got {type(entry)}"
        return ResourceRequest(
            entry.get("resources", {}),
            entry.get("label_selector", {}),
        )

    # Validate and normalize every entry before grouping, so a malformed
    # entry is reported before any shape is hashed.
    requests = [_as_request(entry) for entry in to_request]

    # Aggregate bundles by shape: requests with the same resources and the
    # same label selector collapse into one key with an occurrence count.
    # frozensets of the items make the (unhashable) dicts usable as keys.
    shape_counts = Counter(
        (
            frozenset(req.resources.items()),
            frozenset(req.label_selector.items()),
        )
        for req in requests
    )

    # Unzip the grouped shapes into three parallel lists. Iterating the
    # Counter and its values() both follow insertion order, so index i of
    # each list describes the same shape.
    bundles = [dict(resource_items) for resource_items, _ in shape_counts]
    label_selectors = [dict(selector_items) for _, selector_items in shape_counts]
    counts = list(shape_counts.values())

    GcsClient(gcs_address).request_cluster_resource_constraint(
        bundles, label_selectors, counts, timeout_s=timeout
    )
def get_cluster_status(
    gcs_address: str, timeout: int = DEFAULT_RPC_TIMEOUT_S
) -> ClusterStatus:
    """Query GCS for the autoscaler's view of the cluster status.

    Args:
        gcs_address: The GCS address to query. Must be non-empty.
        timeout: Timeout in seconds for the GCS request.

    Returns:
        A ClusterStatus object parsed from the GCS reply. Its stats record
        the wall-clock time at which the request was issued and how long the
        request (including creating the GCS client) took.

    Raises:
        AssertionError: If `gcs_address` is empty.
    """
    assert len(gcs_address) > 0, "GCS address is not specified."

    # Time the whole round trip, client construction included.
    started_at = time.time()
    serialized_reply = GcsClient(gcs_address).get_cluster_status(timeout_s=timeout)
    finished_at = time.time()

    status_reply = GetClusterStatusReply()
    status_reply.ParseFromString(serialized_reply)

    # TODO(rickyx): To be more accurate, we could add a timestamp field from the reply.
    request_stats = Stats(
        gcs_request_time_s=finished_at - started_at,
        request_ts_s=started_at,
    )
    return ClusterStatusParser.from_get_cluster_status_reply(
        status_reply,
        stats=request_stats,
    )
def get_cluster_resource_state(gcs_client: GcsClient) -> ClusterResourceState:
    """Fetch the cluster resource state through an existing GCS client.

    Args:
        gcs_client: The GCS client to query.

    Returns:
        The ClusterResourceState carried in the GCS reply.

    Raises:
        Exception: If the request times out or failed.
    """
    # The client returns the reply as a serialized protobuf; decode it and
    # hand back only the embedded resource state.
    state_reply = GetClusterResourceStateReply()
    state_reply.ParseFromString(gcs_client.get_cluster_resource_state())
    return state_reply.cluster_resource_state
|