sdk.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import time
  2. from collections import Counter
  3. from typing import List, NamedTuple
  4. from ray._raylet import GcsClient
  5. from ray.autoscaler.v2.schema import ClusterStatus, Stats
  6. from ray.autoscaler.v2.utils import ClusterStatusParser
  7. from ray.core.generated.autoscaler_pb2 import (
  8. ClusterResourceState,
  9. GetClusterResourceStateReply,
  10. GetClusterStatusReply,
  11. )
  12. DEFAULT_RPC_TIMEOUT_S = 10
  13. class ResourceRequest(NamedTuple):
  14. resources: dict
  15. label_selector: dict
  16. def request_cluster_resources(
  17. gcs_address: str,
  18. to_request: List[dict],
  19. timeout: int = DEFAULT_RPC_TIMEOUT_S,
  20. ):
  21. """Request resources from the autoscaler.
  22. This will add a cluster resource constraint to GCS. GCS will asynchronously
  23. pass the constraint to the autoscaler, and the autoscaler will try to provision the
  24. requested minimal bundles in `to_request`.
  25. If the cluster already has `to_request` resources, this will be an no-op.
  26. Future requests submitted through this API will overwrite the previous requests.
  27. Args:
  28. gcs_address: The GCS address to query.
  29. to_request: A list of resource requests to request the cluster to have.
  30. Each resource request is a tuple of resources and a label_selector
  31. to apply per-bundle. e.g.: [{"resources": {"CPU": 1, "GPU": 1}, "label_selector": {"accelerator-type": "A100"}}]
  32. timeout: Timeout in seconds for the request to be timeout
  33. """
  34. assert len(gcs_address) > 0, "GCS address is not specified."
  35. # Convert bundle dicts to ResourceRequest tuples.
  36. normalized: List[ResourceRequest] = []
  37. for r in to_request:
  38. assert isinstance(
  39. r, dict
  40. ), f"Internal Error: Expected a dict, but got {type(r)}"
  41. resources = r.get("resources", {})
  42. selector = r.get("label_selector", {})
  43. normalized.append(ResourceRequest(resources, selector))
  44. to_request = normalized
  45. # Aggregate bundle by shape
  46. def keyfunc(r):
  47. return (
  48. frozenset(r.resources.items()),
  49. frozenset(r.label_selector.items()),
  50. )
  51. grouped_requests = Counter(keyfunc(r) for r in to_request)
  52. bundles: List[dict] = []
  53. label_selectors: List[dict] = []
  54. counts: List[int] = []
  55. for (bundle, selector), count in grouped_requests.items():
  56. bundles.append(dict(bundle))
  57. label_selectors.append(dict(selector))
  58. counts.append(count)
  59. GcsClient(gcs_address).request_cluster_resource_constraint(
  60. bundles, label_selectors, counts, timeout_s=timeout
  61. )
  62. def get_cluster_status(
  63. gcs_address: str, timeout: int = DEFAULT_RPC_TIMEOUT_S
  64. ) -> ClusterStatus:
  65. """
  66. Get the cluster status from the autoscaler.
  67. Args:
  68. gcs_address: The GCS address to query.
  69. timeout: Timeout in seconds for the request to be timeout
  70. Returns:
  71. A ClusterStatus object.
  72. """
  73. assert len(gcs_address) > 0, "GCS address is not specified."
  74. req_time = time.time()
  75. str_reply = GcsClient(gcs_address).get_cluster_status(timeout_s=timeout)
  76. reply_time = time.time()
  77. reply = GetClusterStatusReply()
  78. reply.ParseFromString(str_reply)
  79. # TODO(rickyx): To be more accurate, we could add a timestamp field from the reply.
  80. return ClusterStatusParser.from_get_cluster_status_reply(
  81. reply,
  82. stats=Stats(gcs_request_time_s=reply_time - req_time, request_ts_s=req_time),
  83. )
  84. def get_cluster_resource_state(gcs_client: GcsClient) -> ClusterResourceState:
  85. """
  86. Get the cluster resource state from GCS.
  87. Args:
  88. gcs_client: The GCS client to query.
  89. Returns:
  90. A ClusterResourceState object
  91. Raises:
  92. Exception: If the request times out or failed.
  93. """
  94. str_reply = gcs_client.get_cluster_resource_state()
  95. reply = GetClusterResourceStateReply()
  96. reply.ParseFromString(str_reply)
  97. return reply.cluster_resource_state