batching_node_provider.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. import logging
  2. from collections import defaultdict
  3. from dataclasses import dataclass, field
  4. from typing import Any, Dict, List, Optional, Set
  5. from ray.autoscaler._private.constants import (
  6. DISABLE_LAUNCH_CONFIG_CHECK_KEY,
  7. DISABLE_NODE_UPDATERS_KEY,
  8. FOREGROUND_NODE_LAUNCH_KEY,
  9. )
  10. from ray.autoscaler._private.util import NodeID, NodeIP, NodeKind, NodeStatus, NodeType
  11. from ray.autoscaler.node_provider import NodeProvider
  12. from ray.autoscaler.tags import (
  13. NODE_KIND_HEAD,
  14. TAG_RAY_NODE_KIND,
  15. TAG_RAY_NODE_STATUS,
  16. TAG_RAY_REPLICA_INDEX,
  17. TAG_RAY_USER_NODE_TYPE,
  18. )
# Module-level logger; BatchingNodeProvider.terminate_node uses it to report
# skipped or cascaded termination requests.
logger: logging.Logger = logging.getLogger(__name__)
  20. @dataclass
  21. class ScaleRequest:
  22. """Stores desired scale computed by the autoscaler.
  23. Attributes:
  24. desired_num_workers: Map of worker NodeType to desired number of workers of
  25. that type.
  26. workers_to_delete: List of ids of nodes that should be removed.
  27. """
  28. desired_num_workers: Dict[NodeType, int] = field(default_factory=dict)
  29. workers_to_delete: Set[NodeID] = field(default_factory=set)
  30. @dataclass
  31. class NodeData:
  32. """Stores all data about a Ray node needed by the autoscaler.
  33. Attributes:
  34. kind: Whether the node is the head or a worker.
  35. type: The user-defined type of the node.
  36. replica_index: An identifier for nodes in a replica of a TPU worker group.
  37. This value is set as a Pod label by a GKE webhook when TPUs are requested
  38. ip: Cluster-internal ip of the node. ip can be None if the ip
  39. has not yet been assigned.
  40. status: The status of the node. You must adhere to the following semantics
  41. for status:
  42. * The status must be "up-to-date" if and only if the node is running.
  43. * The status must be "update-failed" if and only if the node is in an
  44. unknown or failed state.
  45. * If the node is in a pending (starting-up) state, the status should be
  46. a brief user-facing description of why the node is pending.
  47. """
  48. kind: NodeKind
  49. type: NodeType
  50. ip: Optional[NodeIP]
  51. status: NodeStatus
  52. replica_index: Optional[str] = None
  53. class BatchingNodeProvider(NodeProvider):
  54. """Abstract subclass of NodeProvider meant for use with external cluster managers.
  55. Batches reads of cluster state into a single method, get_node_data, called at the
  56. start of an autoscaling update.
  57. Batches modifications to cluster state into a single method, submit_scale_request,
  58. called at the end of an autoscaling update.
  59. Implementing a concrete subclass of BatchingNodeProvider only requires overriding
  60. get_node_data() and submit_scale_request().
  61. See the method docstrings for more information.
  62. Note that an autoscaling update may be conditionally
  63. cancelled using the optional method safe_to_scale()
  64. of the root NodeProvider.
  65. """
  66. def __init__(
  67. self,
  68. provider_config: Dict[str, Any],
  69. cluster_name: str,
  70. ) -> None:
  71. NodeProvider.__init__(self, provider_config, cluster_name)
  72. self.node_data_dict: Dict[NodeID, NodeData] = {}
  73. # These flags enforce correct behavior for single-threaded node providers
  74. # which interact with external cluster managers:
  75. assert (
  76. provider_config.get(DISABLE_NODE_UPDATERS_KEY, False) is True
  77. ), f"To use BatchingNodeProvider, must set `{DISABLE_NODE_UPDATERS_KEY}:True`."
  78. assert provider_config.get(DISABLE_LAUNCH_CONFIG_CHECK_KEY, False) is True, (
  79. "To use BatchingNodeProvider, must set "
  80. f"`{DISABLE_LAUNCH_CONFIG_CHECK_KEY}:True`."
  81. )
  82. assert (
  83. provider_config.get(FOREGROUND_NODE_LAUNCH_KEY, False) is True
  84. ), f"To use BatchingNodeProvider, must set `{FOREGROUND_NODE_LAUNCH_KEY}:True`."
  85. # self.scale_change_needed tracks whether we need to update scale.
  86. # set to True in create_node and terminate_nodes calls
  87. # reset to False in non_terminated_nodes, which occurs at the start of the
  88. # autoscaling update. For good measure, also set to false in post_process.
  89. self.scale_change_needed = False
  90. self.scale_request = ScaleRequest()
  91. # Initialize map of replica indices to nodes in that replica
  92. self.replica_index_to_nodes = defaultdict(list[str])
  93. def get_node_data(self) -> Dict[NodeID, NodeData]:
  94. """Queries cluster manager for node info. Returns a mapping from node id to
  95. NodeData.
  96. Each NodeData value must adhere to the semantics of the NodeData docstring.
  97. (Note in particular the requirements for NodeData.status.)
  98. Consistency requirement:
  99. If a node id was present in ScaleRequest.workers_to_delete of a previously
  100. submitted scale request, it should no longer be present as a key in
  101. get_node_data.
  102. (Node termination must be registered immediately when submit_scale_request
  103. returns.)
  104. """
  105. raise NotImplementedError
  106. def submit_scale_request(self, scale_request: ScaleRequest) -> None:
  107. """Tells the cluster manager which nodes to delete and how many nodes of
  108. each node type to maintain.
  109. Consistency requirement:
  110. If a node id was present in ScaleRequest.workers_to_delete of a previously
  111. submitted scale request, it should no longer be present as key in get_node_data.
  112. (Node termination must be registered immediately when submit_scale_request
  113. returns.)
  114. """
  115. raise NotImplementedError
  116. def post_process(self) -> None:
  117. """Submit a scale request if it is necessary to do so."""
  118. if self.scale_change_needed:
  119. self.submit_scale_request(self.scale_request)
  120. self.scale_change_needed = False
  121. def non_terminated_nodes(self, tag_filters: Dict[str, str]) -> List[str]:
  122. self.scale_change_needed = False
  123. self.node_data_dict = self.get_node_data()
  124. # Initialize ScaleRequest
  125. self.scale_request = ScaleRequest(
  126. desired_num_workers=self.cur_num_workers(), # Current scale
  127. workers_to_delete=set(), # No workers to delete yet
  128. )
  129. all_nodes = list(self.node_data_dict.keys())
  130. self.replica_index_to_nodes.clear()
  131. for node_id in all_nodes:
  132. replica_index = self.node_data_dict[node_id].replica_index
  133. # Only add node to map if it belongs to a multi-host podslice
  134. if replica_index is not None:
  135. self.replica_index_to_nodes[replica_index].append(node_id)
  136. # Support filtering by TAG_RAY_NODE_KIND, TAG_RAY_NODE_STATUS, and
  137. # TAG_RAY_USER_NODE_TYPE.
  138. # The autoscaler only uses tag_filters={},
  139. # but filtering by the these keys is useful for testing.
  140. filtered_nodes = [
  141. node
  142. for node in all_nodes
  143. if tag_filters.items() <= self.node_tags(node).items()
  144. ]
  145. return filtered_nodes
  146. def cur_num_workers(self):
  147. """Returns dict mapping node type to the number of nodes of that type."""
  148. # Factor like this for convenient re-use.
  149. return self._cur_num_workers(self.node_data_dict)
  150. def _cur_num_workers(self, node_data_dict: Dict[str, Any]):
  151. num_workers_dict = defaultdict(int)
  152. for node_data in node_data_dict.values():
  153. if node_data.kind == NODE_KIND_HEAD:
  154. # Only track workers.
  155. continue
  156. num_workers_dict[node_data.type] += 1
  157. return num_workers_dict
  158. def node_tags(self, node_id: str) -> Dict[str, str]:
  159. node_data = self.node_data_dict[node_id]
  160. tags = {
  161. TAG_RAY_NODE_KIND: node_data.kind,
  162. TAG_RAY_NODE_STATUS: node_data.status,
  163. TAG_RAY_USER_NODE_TYPE: node_data.type,
  164. }
  165. if node_data.replica_index is not None:
  166. tags[TAG_RAY_REPLICA_INDEX] = node_data.replica_index
  167. return tags
  168. def internal_ip(self, node_id: str) -> str:
  169. return self.node_data_dict[node_id].ip
  170. def create_node(
  171. self, node_config: Dict[str, Any], tags: Dict[str, str], count: int
  172. ) -> Optional[Dict[str, Any]]:
  173. node_type = tags[TAG_RAY_USER_NODE_TYPE]
  174. self.scale_request.desired_num_workers[node_type] += count
  175. self.scale_change_needed = True
  176. def terminate_node(self, node_id: str) -> Optional[Dict[str, Any]]:
  177. # Sanity check: We should never try to delete the same node twice.
  178. if node_id in self.scale_request.workers_to_delete:
  179. logger.warning(
  180. f"Autoscaler tried to terminate node {node_id} twice in the same update"
  181. ". Skipping termination request."
  182. )
  183. return
  184. # Sanity check: We should never try to delete a node we haven't seen.
  185. if node_id not in self.node_data_dict:
  186. logger.warning(
  187. f"Autoscaler tried to terminate unkown node {node_id}"
  188. ". Skipping termination request."
  189. )
  190. return
  191. node_type = self.node_data_dict[node_id].type
  192. # Sanity check: Don't request less than 0 nodes.
  193. if self.scale_request.desired_num_workers[node_type] <= 0:
  194. # This is logically impossible.
  195. raise AssertionError(
  196. "NodeProvider attempted to request less than 0 workers of type "
  197. f"{node_type}. Skipping termination request."
  198. )
  199. # Terminate node
  200. self.scale_request.desired_num_workers[node_type] -= 1
  201. self.scale_request.workers_to_delete.add(node_id)
  202. # Scale down all nodes in replica if node_id is part of a multi-host podslice
  203. tags = self.node_tags(node_id)
  204. if TAG_RAY_REPLICA_INDEX in tags:
  205. node_replica_index = tags[TAG_RAY_REPLICA_INDEX]
  206. for worker_id in self.replica_index_to_nodes[node_replica_index]:
  207. # Check if worker has already been scheduled to delete
  208. if worker_id not in self.scale_request.workers_to_delete:
  209. self.scale_request.workers_to_delete.add(worker_id)
  210. logger.info(
  211. f"Autoscaler terminating node {worker_id} "
  212. f"in multi-host replica {node_replica_index}."
  213. )
  214. self.scale_change_needed = True