node.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856
"""Abstractions around GCP resources and nodes.

The logic has been abstracted away here to allow for different GCP resources
(API endpoints), which can differ widely, making it impossible to use
the same logic for everything.

Classes inheriting from ``GCPResource`` represent different GCP resources -
API endpoints that allow for nodes to be created, removed, listed and
otherwise managed. Those classes contain methods abstracting GCP REST API
calls.

Each resource has a corresponding node type, represented by a
class inheriting from ``GCPNode``. Those classes are essentially dicts
with some extra methods. The instances of those classes will be created
from API responses.

The ``GCPNodeType`` enum is a lightweight way to classify nodes.

Currently, Compute and TPU resources & nodes are supported.

In order to add support for new resources, create classes inheriting from
``GCPResource`` and ``GCPNode``, update the ``GCPNodeType`` enum,
update the ``_generate_node_name`` method and finally update the
node provider.
"""
import abc
import logging
import re
import time
from collections import UserDict
from copy import deepcopy
from enum import Enum
from functools import wraps
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from uuid import uuid4

import httplib2
from google_auth_httplib2 import AuthorizedHttp
from googleapiclient.discovery import Resource
from googleapiclient.errors import HttpError

from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME

  35. logger = logging.getLogger(__name__)
  36. INSTANCE_NAME_MAX_LEN = 64
  37. INSTANCE_NAME_UUID_LEN = 8
  38. MAX_POLLS = 12
  39. # TPUs take a long while to respond, so we increase the MAX_POLLS
  40. # considerably - this probably could be smaller
  41. # TPU deletion uses MAX_POLLS
  42. MAX_POLLS_TPU = MAX_POLLS * 8
  43. POLL_INTERVAL = 5
  44. def _retry_on_exception(
  45. exception: Union[Exception, Tuple[Exception]],
  46. regex: Optional[str] = None,
  47. max_retries: int = MAX_POLLS,
  48. retry_interval_s: int = POLL_INTERVAL,
  49. ):
  50. """Retry a function call n-times for as long as it throws an exception."""
  51. def dec(func):
  52. @wraps(func)
  53. def wrapper(*args, **kwargs):
  54. def try_catch_exc():
  55. try:
  56. value = func(*args, **kwargs)
  57. return value
  58. except Exception as e:
  59. if not isinstance(e, exception) or (
  60. regex and not re.search(regex, str(e))
  61. ):
  62. raise e
  63. return e
  64. for _ in range(max_retries):
  65. ret = try_catch_exc()
  66. if not isinstance(ret, Exception):
  67. break
  68. time.sleep(retry_interval_s)
  69. if isinstance(ret, Exception):
  70. raise ret
  71. return ret
  72. return wrapper
  73. return dec
  74. def _generate_node_name(labels: dict, node_suffix: str) -> str:
  75. """Generate node name from labels and suffix.
  76. This is required so that the correct resource can be selected
  77. when the only information autoscaler has is the name of the node.
  78. The suffix is expected to be one of 'compute' or 'tpu'
  79. (as in ``GCPNodeType``).
  80. """
  81. name_label = labels[TAG_RAY_NODE_NAME]
  82. assert len(name_label) <= (INSTANCE_NAME_MAX_LEN - INSTANCE_NAME_UUID_LEN - 1), (
  83. name_label,
  84. len(name_label),
  85. )
  86. return f"{name_label}-{uuid4().hex[:INSTANCE_NAME_UUID_LEN]}-{node_suffix}"
  87. class GCPNodeType(Enum):
  88. """Enum for GCP node types (compute & tpu)"""
  89. COMPUTE = "compute"
  90. TPU = "tpu"
  91. @staticmethod
  92. def from_gcp_node(node: "GCPNode"):
  93. """Return GCPNodeType based on ``node``'s class"""
  94. if isinstance(node, GCPTPUNode):
  95. return GCPNodeType.TPU
  96. if isinstance(node, GCPComputeNode):
  97. return GCPNodeType.COMPUTE
  98. raise TypeError(f"Wrong GCPNode type {type(node)}.")
  99. @staticmethod
  100. def name_to_type(name: str):
  101. """Provided a node name, determine the type.
  102. This expects the name to be in format '[NAME]-[UUID]-[TYPE]',
  103. where [TYPE] is either 'compute' or 'tpu'.
  104. """
  105. return GCPNodeType(name.split("-")[-1])
  106. class GCPNode(UserDict, metaclass=abc.ABCMeta):
  107. """Abstraction around compute and tpu nodes"""
  108. NON_TERMINATED_STATUSES = None
  109. RUNNING_STATUSES = None
  110. STATUS_FIELD = None
  111. def __init__(self, base_dict: dict, resource: "GCPResource", **kwargs) -> None:
  112. super().__init__(base_dict, **kwargs)
  113. self.resource = resource
  114. assert isinstance(self.resource, GCPResource)
  115. def is_running(self) -> bool:
  116. return self.get(self.STATUS_FIELD) in self.RUNNING_STATUSES
  117. def is_terminated(self) -> bool:
  118. return self.get(self.STATUS_FIELD) not in self.NON_TERMINATED_STATUSES
  119. @abc.abstractmethod
  120. def get_labels(self) -> dict:
  121. return
  122. @abc.abstractmethod
  123. def get_external_ip(self) -> str:
  124. return
  125. @abc.abstractmethod
  126. def get_internal_ip(self) -> str:
  127. return
  128. def __repr__(self) -> str:
  129. return f"<{self.__class__.__name__}: {self.get('name')}>"
  130. class GCPComputeNode(GCPNode):
  131. """Abstraction around compute nodes"""
  132. # https://cloud.google.com/compute/docs/instances/instance-life-cycle
  133. NON_TERMINATED_STATUSES = {"PROVISIONING", "STAGING", "RUNNING"}
  134. TERMINATED_STATUSES = {"TERMINATED", "SUSPENDED"}
  135. RUNNING_STATUSES = {"RUNNING"}
  136. STATUS_FIELD = "status"
  137. def get_labels(self) -> dict:
  138. return self.get("labels", {})
  139. def get_external_ip(self) -> str:
  140. return (
  141. self.get("networkInterfaces", [{}])[0]
  142. .get("accessConfigs", [{}])[0]
  143. .get("natIP", None)
  144. )
  145. def get_internal_ip(self) -> str:
  146. return self.get("networkInterfaces", [{}])[0].get("networkIP")
  147. class GCPTPUNode(GCPNode):
  148. """Abstraction around tpu nodes"""
  149. # https://cloud.google.com/tpu/docs/reference/rest/v2alpha1/projects.locations.nodes#State
  150. NON_TERMINATED_STATUSES = {"CREATING", "STARTING", "RESTARTING", "READY"}
  151. RUNNING_STATUSES = {"READY"}
  152. STATUS_FIELD = "state"
  153. def get_labels(self) -> dict:
  154. return self.get("labels", {})
  155. @property
  156. def num_workers(self) -> int:
  157. return len(self.get("networkEndpoints", [{}]))
  158. def get_external_ips(self) -> List[str]:
  159. return self.get("networkEndpoints", [{}])
  160. def get_external_ip(self, worker_index: int = 0) -> str:
  161. return (
  162. self.get_external_ips()[worker_index]
  163. .get("accessConfig", {})
  164. .get("externalIp", None)
  165. )
  166. def get_internal_ips(self) -> List[str]:
  167. return self.get("networkEndpoints", [{}])
  168. def get_internal_ip(self, worker_index: int = 0) -> str:
  169. return self.get_internal_ips()[worker_index].get("ipAddress", None)
  170. class GCPResource(metaclass=abc.ABCMeta):
  171. """Abstraction around compute and TPU resources"""
  172. def __init__(
  173. self,
  174. resource: Resource,
  175. project_id: str,
  176. availability_zone: str,
  177. cluster_name: str,
  178. ) -> None:
  179. self.resource = resource
  180. self.project_id = project_id
  181. self.availability_zone = availability_zone
  182. self.cluster_name = cluster_name
  183. @abc.abstractmethod
  184. def get_new_authorized_http(self, http: AuthorizedHttp) -> AuthorizedHttp:
  185. """Generate a new AuthorizedHttp object with the given credentials."""
  186. return
  187. @abc.abstractmethod
  188. def wait_for_operation(
  189. self,
  190. operation: dict,
  191. max_polls: int = MAX_POLLS,
  192. poll_interval: int = POLL_INTERVAL,
  193. ) -> dict:
  194. """Waits a preset amount of time for operation to complete."""
  195. return None
  196. @abc.abstractmethod
  197. def list_instances(
  198. self,
  199. label_filters: Optional[dict] = None,
  200. is_terminated: bool = False,
  201. ) -> List["GCPNode"]:
  202. """Returns a filtered list of all instances.
  203. The filter removes all terminated instances and, if ``label_filters``
  204. are provided, all instances which labels are not matching the
  205. ones provided.
  206. """
  207. return
  208. @abc.abstractmethod
  209. def get_instance(self, node_id: str) -> "GCPNode":
  210. """Returns a single instance."""
  211. return
  212. @abc.abstractmethod
  213. def set_labels(
  214. self, node: GCPNode, labels: dict, wait_for_operation: bool = True
  215. ) -> dict:
  216. """Sets labels on an instance and returns result.
  217. Completely replaces the labels dictionary."""
  218. return
  219. @abc.abstractmethod
  220. def create_instance(
  221. self, base_config: dict, labels: dict, wait_for_operation: bool = True
  222. ) -> Tuple[dict, str]:
  223. """Creates a single instance and returns result.
  224. Returns a tuple of (result, node_name).
  225. """
  226. return
  227. def create_instances(
  228. self,
  229. base_config: dict,
  230. labels: dict,
  231. count: int,
  232. wait_for_operation: bool = True,
  233. ) -> List[Tuple[dict, str]]:
  234. """Creates multiple instances and returns result.
  235. Returns a list of tuples of (result, node_name).
  236. """
  237. operations = [
  238. self.create_instance(base_config, labels, wait_for_operation=False)
  239. for i in range(count)
  240. ]
  241. if wait_for_operation:
  242. results = [
  243. (self.wait_for_operation(operation), node_name)
  244. for operation, node_name in operations
  245. ]
  246. else:
  247. results = operations
  248. return results
  249. @abc.abstractmethod
  250. def delete_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
  251. """Deletes an instance and returns result."""
  252. return
  253. @abc.abstractmethod
  254. def stop_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
  255. """Deletes an instance and returns result."""
  256. return
  257. @abc.abstractmethod
  258. def start_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
  259. """Starts a single instance and returns result."""
  260. return
  261. class GCPCompute(GCPResource):
  262. """Abstraction around GCP compute resource"""
  263. def get_new_authorized_http(self, http: AuthorizedHttp) -> AuthorizedHttp:
  264. """Generate a new AuthorizedHttp object with the given credentials."""
  265. new_http = AuthorizedHttp(http.credentials, http=httplib2.Http())
  266. return new_http
  267. def wait_for_operation(
  268. self,
  269. operation: dict,
  270. max_polls: int = MAX_POLLS,
  271. poll_interval: int = POLL_INTERVAL,
  272. ) -> dict:
  273. """Poll for compute zone operation until finished."""
  274. logger.info(
  275. "wait_for_compute_zone_operation: "
  276. f"Waiting for operation {operation['name']} to finish..."
  277. )
  278. for _ in range(max_polls):
  279. result = (
  280. self.resource.zoneOperations()
  281. .get(
  282. project=self.project_id,
  283. operation=operation["name"],
  284. zone=self.availability_zone,
  285. )
  286. .execute(http=self.get_new_authorized_http(self.resource._http))
  287. )
  288. if "error" in result:
  289. raise Exception(result["error"])
  290. if result["status"] == "DONE":
  291. logger.info(
  292. "wait_for_compute_zone_operation: "
  293. f"Operation {operation['name']} finished."
  294. )
  295. break
  296. time.sleep(poll_interval)
  297. return result
  298. def list_instances(
  299. self,
  300. label_filters: Optional[dict] = None,
  301. is_terminated: bool = False,
  302. ) -> List[GCPComputeNode]:
  303. label_filters = label_filters or {}
  304. if label_filters:
  305. label_filter_expr = (
  306. "("
  307. + " AND ".join(
  308. [
  309. "(labels.{key} = {value})".format(key=key, value=value)
  310. for key, value in label_filters.items()
  311. ]
  312. )
  313. + ")"
  314. )
  315. else:
  316. label_filter_expr = ""
  317. statuses = (
  318. GCPComputeNode.TERMINATED_STATUSES
  319. if is_terminated
  320. else GCPComputeNode.NON_TERMINATED_STATUSES
  321. )
  322. instance_state_filter_expr = (
  323. "("
  324. + " OR ".join(
  325. ["(status = {status})".format(status=status) for status in statuses]
  326. )
  327. + ")"
  328. )
  329. cluster_name_filter_expr = "(labels.{key} = {value})".format(
  330. key=TAG_RAY_CLUSTER_NAME, value=self.cluster_name
  331. )
  332. # TPU VMs spawn accompanying Compute Instances that must be filtered out,
  333. # else this results in duplicated nodes.
  334. tpu_negation_filter_expr = "(NOT labels.{label}:*)".format(label="tpu_cores")
  335. not_empty_filters = [
  336. f
  337. for f in [
  338. label_filter_expr,
  339. instance_state_filter_expr,
  340. cluster_name_filter_expr,
  341. tpu_negation_filter_expr,
  342. ]
  343. if f
  344. ]
  345. filter_expr = " AND ".join(not_empty_filters)
  346. response = (
  347. self.resource.instances()
  348. .list(
  349. project=self.project_id,
  350. zone=self.availability_zone,
  351. filter=filter_expr,
  352. )
  353. .execute(http=self.get_new_authorized_http(self.resource._http))
  354. )
  355. instances = response.get("items", [])
  356. return [GCPComputeNode(i, self) for i in instances]
  357. def get_instance(self, node_id: str) -> GCPComputeNode:
  358. instance = (
  359. self.resource.instances()
  360. .get(
  361. project=self.project_id,
  362. zone=self.availability_zone,
  363. instance=node_id,
  364. )
  365. .execute()
  366. )
  367. return GCPComputeNode(instance, self)
  368. def set_labels(
  369. self, node: GCPComputeNode, labels: dict, wait_for_operation: bool = True
  370. ) -> dict:
  371. body = {
  372. "labels": dict(node["labels"], **labels),
  373. "labelFingerprint": node["labelFingerprint"],
  374. }
  375. node_id = node["name"]
  376. operation = (
  377. self.resource.instances()
  378. .setLabels(
  379. project=self.project_id,
  380. zone=self.availability_zone,
  381. instance=node_id,
  382. body=body,
  383. )
  384. .execute(http=self.get_new_authorized_http(self.resource._http))
  385. )
  386. if wait_for_operation:
  387. result = self.wait_for_operation(operation)
  388. else:
  389. result = operation
  390. return result
  391. def _convert_resources_to_urls(
  392. self, configuration_dict: Dict[str, Any]
  393. ) -> Dict[str, Any]:
  394. """Ensures that resources are in their full URL form.
  395. GCP expects machineType and acceleratorType to be a full URL (e.g.
  396. `zones/us-west1/machineTypes/n1-standard-2`) instead of just the
  397. type (`n1-standard-2`)
  398. Args:
  399. configuration_dict: Dict of options that will be passed to GCP
  400. Returns:
  401. Input dictionary, but with possibly expanding `machineType` and
  402. `acceleratorType`.
  403. """
  404. configuration_dict = deepcopy(configuration_dict)
  405. existing_machine_type = configuration_dict["machineType"]
  406. if not re.search(".*/machineTypes/.*", existing_machine_type):
  407. configuration_dict[
  408. "machineType"
  409. ] = "zones/{zone}/machineTypes/{machine_type}".format(
  410. zone=self.availability_zone,
  411. machine_type=configuration_dict["machineType"],
  412. )
  413. for accelerator in configuration_dict.get("guestAccelerators", []):
  414. gpu_type = accelerator["acceleratorType"]
  415. if not re.search(".*/acceleratorTypes/.*", gpu_type):
  416. accelerator[
  417. "acceleratorType"
  418. ] = "projects/{project}/zones/{zone}/acceleratorTypes/{accelerator}".format( # noqa: E501
  419. project=self.project_id,
  420. zone=self.availability_zone,
  421. accelerator=gpu_type,
  422. )
  423. return configuration_dict
  424. def create_instance(
  425. self, base_config: dict, labels: dict, wait_for_operation: bool = True
  426. ) -> Tuple[dict, str]:
  427. config = self._convert_resources_to_urls(base_config)
  428. # removing TPU-specific default key set in config.py
  429. config.pop("networkConfig", None)
  430. name = _generate_node_name(labels, GCPNodeType.COMPUTE.value)
  431. labels = dict(config.get("labels", {}), **labels)
  432. config.update(
  433. {
  434. "labels": dict(labels, **{TAG_RAY_CLUSTER_NAME: self.cluster_name}),
  435. "name": name,
  436. }
  437. )
  438. # Allow Google Compute Engine instance templates.
  439. #
  440. # Config example:
  441. #
  442. # ...
  443. # node_config:
  444. # sourceInstanceTemplate: global/instanceTemplates/worker-16
  445. # machineType: e2-standard-16
  446. # ...
  447. #
  448. # node_config parameters override matching template parameters, if any.
  449. #
  450. # https://cloud.google.com/compute/docs/instance-templates
  451. # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert
  452. source_instance_template = config.pop("sourceInstanceTemplate", None)
  453. operation = (
  454. self.resource.instances()
  455. .insert(
  456. project=self.project_id,
  457. zone=self.availability_zone,
  458. sourceInstanceTemplate=source_instance_template,
  459. body=config,
  460. )
  461. .execute(http=self.get_new_authorized_http(self.resource._http))
  462. )
  463. if wait_for_operation:
  464. result = self.wait_for_operation(operation)
  465. else:
  466. result = operation
  467. return result, name
  468. def delete_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
  469. operation = (
  470. self.resource.instances()
  471. .delete(
  472. project=self.project_id,
  473. zone=self.availability_zone,
  474. instance=node_id,
  475. )
  476. .execute()
  477. )
  478. if wait_for_operation:
  479. result = self.wait_for_operation(operation)
  480. else:
  481. result = operation
  482. return result
  483. def stop_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
  484. operation = (
  485. self.resource.instances()
  486. .stop(
  487. project=self.project_id,
  488. zone=self.availability_zone,
  489. instance=node_id,
  490. )
  491. .execute()
  492. )
  493. if wait_for_operation:
  494. result = self.wait_for_operation(operation)
  495. else:
  496. result = operation
  497. return result
  498. def start_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
  499. operation = (
  500. self.resource.instances()
  501. .start(
  502. project=self.project_id,
  503. zone=self.availability_zone,
  504. instance=node_id,
  505. )
  506. .execute(http=self.get_new_authorized_http(self.resource._http))
  507. )
  508. if wait_for_operation:
  509. result = self.wait_for_operation(operation)
  510. else:
  511. result = operation
  512. return result
  513. class GCPTPU(GCPResource):
  514. """Abstraction around GCP TPU resource"""
  515. # node names already contain the path, but this is required for `parent`
  516. # arguments
  517. @property
  518. def path(self):
  519. return f"projects/{self.project_id}/locations/{self.availability_zone}"
  520. def get_new_authorized_http(self, http: AuthorizedHttp) -> AuthorizedHttp:
  521. """Generate a new AuthorizedHttp object with the given credentials."""
  522. new_http = AuthorizedHttp(http.credentials, http=httplib2.Http())
  523. return new_http
  524. def wait_for_operation(
  525. self,
  526. operation: dict,
  527. max_polls: int = MAX_POLLS_TPU,
  528. poll_interval: int = POLL_INTERVAL,
  529. ) -> dict:
  530. """Poll for TPU operation until finished."""
  531. logger.info(
  532. "wait_for_tpu_operation: "
  533. f"Waiting for operation {operation['name']} to finish..."
  534. )
  535. for _ in range(max_polls):
  536. result = (
  537. self.resource.projects()
  538. .locations()
  539. .operations()
  540. .get(name=f"{operation['name']}")
  541. .execute(http=self.get_new_authorized_http(self.resource._http))
  542. )
  543. if "error" in result:
  544. raise Exception(result["error"])
  545. if "response" in result:
  546. logger.info(
  547. "wait_for_tpu_operation: "
  548. f"Operation {operation['name']} finished."
  549. )
  550. break
  551. time.sleep(poll_interval)
  552. return result
  553. def list_instances(
  554. self,
  555. label_filters: Optional[dict] = None,
  556. is_terminated: bool = False,
  557. ) -> List[GCPTPUNode]:
  558. response = (
  559. self.resource.projects()
  560. .locations()
  561. .nodes()
  562. .list(parent=self.path)
  563. .execute(http=self.get_new_authorized_http(self.resource._http))
  564. )
  565. instances = response.get("nodes", [])
  566. instances = [GCPTPUNode(i, self) for i in instances]
  567. # filter_expr cannot be passed directly to API
  568. # so we need to filter the results ourselves
  569. # same logic as in GCPCompute.list_instances
  570. label_filters = label_filters or {}
  571. label_filters[TAG_RAY_CLUSTER_NAME] = self.cluster_name
  572. def filter_instance(instance: GCPTPUNode) -> bool:
  573. if instance.is_terminated():
  574. return False
  575. labels = instance.get_labels()
  576. if label_filters:
  577. for key, value in label_filters.items():
  578. if key not in labels:
  579. return False
  580. if value != labels[key]:
  581. return False
  582. return True
  583. instances = list(filter(filter_instance, instances))
  584. return instances
  585. def get_instance(self, node_id: str) -> GCPTPUNode:
  586. instance = (
  587. self.resource.projects()
  588. .locations()
  589. .nodes()
  590. .get(name=node_id)
  591. .execute(http=self.get_new_authorized_http(self.resource._http))
  592. )
  593. return GCPTPUNode(instance, self)
  594. # this sometimes fails without a clear reason, so we retry it
  595. # MAX_POLLS times
  596. @_retry_on_exception(HttpError, "unable to queue the operation")
  597. def set_labels(
  598. self, node: GCPTPUNode, labels: dict, wait_for_operation: bool = True
  599. ) -> dict:
  600. body = {
  601. "labels": dict(node["labels"], **labels),
  602. }
  603. update_mask = "labels"
  604. operation = (
  605. self.resource.projects()
  606. .locations()
  607. .nodes()
  608. .patch(
  609. name=node["name"],
  610. updateMask=update_mask,
  611. body=body,
  612. )
  613. .execute(http=self.get_new_authorized_http(self.resource._http))
  614. )
  615. if wait_for_operation:
  616. result = self.wait_for_operation(operation)
  617. else:
  618. result = operation
  619. return result
  620. def create_instance(
  621. self, base_config: dict, labels: dict, wait_for_operation: bool = True
  622. ) -> Tuple[dict, str]:
  623. config = base_config.copy()
  624. # removing Compute-specific default key set in config.py
  625. config.pop("networkInterfaces", None)
  626. name = _generate_node_name(labels, GCPNodeType.TPU.value)
  627. labels = dict(config.get("labels", {}), **labels)
  628. config.update(
  629. {
  630. "labels": dict(labels, **{TAG_RAY_CLUSTER_NAME: self.cluster_name}),
  631. }
  632. )
  633. if "networkConfig" not in config:
  634. config["networkConfig"] = {}
  635. if "enableExternalIps" not in config["networkConfig"]:
  636. # this is required for SSH to work, per google documentation
  637. # https://cloud.google.com/tpu/docs/users-guide-tpu-vm#create-curl
  638. config["networkConfig"]["enableExternalIps"] = True
  639. # replace serviceAccounts with serviceAccount, and scopes with scope
  640. # this is necessary for the head node to work
  641. # see here: https://tpu.googleapis.com/$discovery/rest?version=v2alpha1
  642. if "serviceAccounts" in config:
  643. config["serviceAccount"] = config.pop("serviceAccounts")[0]
  644. config["serviceAccount"]["scope"] = config["serviceAccount"].pop("scopes")
  645. operation = (
  646. self.resource.projects()
  647. .locations()
  648. .nodes()
  649. .create(
  650. parent=self.path,
  651. body=config,
  652. nodeId=name,
  653. )
  654. .execute(http=self.get_new_authorized_http(self.resource._http))
  655. )
  656. if wait_for_operation:
  657. result = self.wait_for_operation(operation)
  658. else:
  659. result = operation
  660. return result, name
  661. def delete_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
  662. operation = (
  663. self.resource.projects()
  664. .locations()
  665. .nodes()
  666. .delete(name=node_id)
  667. .execute(http=self.get_new_authorized_http(self.resource._http))
  668. )
  669. # No need to increase MAX_POLLS for deletion
  670. if wait_for_operation:
  671. result = self.wait_for_operation(operation, max_polls=MAX_POLLS)
  672. else:
  673. result = operation
  674. return result
  675. def stop_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
  676. operation = (
  677. self.resource.projects()
  678. .locations()
  679. .nodes()
  680. .stop(name=node_id)
  681. .execute(http=self.get_new_authorized_http(self.resource._http))
  682. )
  683. if wait_for_operation:
  684. result = self.wait_for_operation(operation, max_polls=MAX_POLLS)
  685. else:
  686. result = operation
  687. return result
  688. def start_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
  689. operation = (
  690. self.resource.projects()
  691. .locations()
  692. .nodes()
  693. .start(name=node_id)
  694. .execute(http=self.get_new_authorized_http(self.resource._http))
  695. )
  696. if wait_for_operation:
  697. result = self.wait_for_operation(operation, max_polls=MAX_POLLS)
  698. else:
  699. result = operation
  700. return result