config.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874
  1. import copy
  2. import json
  3. import logging
  4. import os
  5. import re
  6. import time
  7. from functools import partial, reduce
  8. import google_auth_httplib2
  9. import googleapiclient
  10. import httplib2
  11. from google.oauth2 import service_account
  12. from google.oauth2.credentials import Credentials as OAuthCredentials
  13. from googleapiclient import discovery, errors
  14. from ray._private.accelerators import TPUAcceleratorManager, tpu
  15. from ray.autoscaler._private.gcp.node import MAX_POLLS, POLL_INTERVAL, GCPNodeType
  16. from ray.autoscaler._private.util import (
  17. check_legacy_fields,
  18. generate_rsa_key_pair,
  19. generate_ssh_key_name,
  20. generate_ssh_key_paths,
  21. )
  22. logger = logging.getLogger(__name__)
  23. VERSION = "v1"
  24. TPU_VERSION = "v2alpha" # change once v2 is stable
  25. RAY = "ray-autoscaler"
  26. DEFAULT_SERVICE_ACCOUNT_ID = RAY + "-sa-" + VERSION
  27. SERVICE_ACCOUNT_EMAIL_TEMPLATE = "{account_id}@{project_id}.iam.gserviceaccount.com"
  28. DEFAULT_SERVICE_ACCOUNT_CONFIG = {
  29. "displayName": "Ray Autoscaler Service Account ({})".format(VERSION),
  30. }
  31. # Those roles will be always added.
  32. # NOTE: `serviceAccountUser` allows the head node to create workers with
  33. # a serviceAccount. `roleViewer` allows the head node to run bootstrap_gcp.
  34. DEFAULT_SERVICE_ACCOUNT_ROLES = [
  35. "roles/storage.objectAdmin",
  36. "roles/compute.admin",
  37. "roles/iam.serviceAccountUser",
  38. "roles/iam.roleViewer",
  39. ]
  40. # Those roles will only be added if there are TPU nodes defined in config.
  41. TPU_SERVICE_ACCOUNT_ROLES = ["roles/tpu.admin"]
  42. # If there are TPU nodes in config, this field will be set
  43. # to True in config["provider"].
  44. HAS_TPU_PROVIDER_FIELD = "_has_tpus"
  45. # NOTE: iam.serviceAccountUser allows the Head Node to create worker nodes
  46. # with ServiceAccounts.
  47. def tpu_accelerator_config_to_type(accelerator_config: dict) -> str:
  48. """Convert a provided accelerator_config to accelerator_type.
  49. Args:
  50. accelerator_config: A dictionary defining the spec of a
  51. TPU accelerator. The dictionary should consist of
  52. the keys 'type', indicating the TPU chip type, and
  53. 'topology', indicating the topology of the TPU.
  54. Returns:
  55. A string, accelerator_type, e.g. "v4-8".
  56. """
  57. generation = accelerator_config["type"].lower()
  58. topology = accelerator_config["topology"]
  59. # Reduce e.g. "2x2x2" to 8
  60. chip_dimensions = [int(chip_count) for chip_count in topology.split("x")]
  61. num_chips = reduce(lambda x, y: x * y, chip_dimensions)
  62. # V5LitePod is rendered as "V5LITE_POD" in accelerator configuration but
  63. # accelerator type uses a format like "v5litepod-{cores}", so we need
  64. # to manually convert the string here.
  65. if generation == "v5lite_pod":
  66. generation = "v5litepod"
  67. num_cores = tpu.get_tpu_cores_per_chip(generation) * num_chips
  68. return f"{generation}-{num_cores}"
  69. def _validate_tpu_config(node: dict):
  70. """Validate the provided node with TPU support.
  71. If the config is malformed, users will run into an error but this function
  72. will raise the error at config parsing time. This only tests very simple assertions.
  73. Raises: `ValueError` in case the input is malformed.
  74. """
  75. if "acceleratorType" in node and "acceleratorConfig" in node:
  76. raise ValueError(
  77. "For TPU usage, acceleratorType and acceleratorConfig "
  78. "cannot both be set."
  79. )
  80. if "acceleratorType" in node:
  81. accelerator_type = node["acceleratorType"]
  82. if not TPUAcceleratorManager.is_valid_tpu_accelerator_type(accelerator_type):
  83. raise ValueError(
  84. "`acceleratorType` should match v(generation)-(cores/chips). "
  85. f"Got {accelerator_type}."
  86. )
  87. else: # "acceleratorConfig" in node
  88. accelerator_config = node["acceleratorConfig"]
  89. if "type" not in accelerator_config or "topology" not in accelerator_config:
  90. raise ValueError(
  91. "acceleratorConfig expects 'type' and 'topology'. "
  92. f"Got {accelerator_config}"
  93. )
  94. generation = node["acceleratorConfig"]["type"]
  95. topology = node["acceleratorConfig"]["topology"]
  96. generation_pattern = re.compile(r"^V\d+[a-zA-Z]*$")
  97. topology_pattern = re.compile(r"^\d+x\d+(x\d+)?$")
  98. if generation != "V5LITE_POD" and not generation_pattern.match(generation):
  99. raise ValueError(f"type should match V(generation). Got {generation}.")
  100. if generation == "V2" or generation == "V3":
  101. raise ValueError(
  102. f"acceleratorConfig is not supported on V2/V3 TPUs. Got {generation}."
  103. )
  104. if not topology_pattern.match(topology):
  105. raise ValueError(
  106. f"topology should be of form axbxc or axb. Got {topology}."
  107. )
  108. def _get_num_tpu_chips(node: dict) -> int:
  109. chips = 0
  110. if "acceleratorType" in node:
  111. accelerator_type = node["acceleratorType"]
  112. # `acceleratorType` is typically v{generation}-{cores}
  113. cores = int(accelerator_type.split("-")[1])
  114. chips = cores / tpu.get_tpu_cores_per_chip(accelerator_type)
  115. if "acceleratorConfig" in node:
  116. topology = node["acceleratorConfig"]["topology"]
  117. # `topology` is typically {chips}x{chips}x{chips}
  118. # Multiply all dimensions together to get total number of chips
  119. chips = 1
  120. for dim in topology.split("x"):
  121. chips *= int(dim)
  122. return chips
  123. def _is_single_host_tpu(node: dict) -> bool:
  124. accelerator_type = ""
  125. if "acceleratorType" in node:
  126. accelerator_type = node["acceleratorType"]
  127. else:
  128. accelerator_type = tpu_accelerator_config_to_type(node["acceleratorConfig"])
  129. return _get_num_tpu_chips(node) <= tpu.get_num_tpu_visible_chips_per_host(
  130. accelerator_type
  131. )
  132. def get_node_type(node: dict) -> GCPNodeType:
  133. """Returns node type based on the keys in ``node``.
  134. This is a very simple check. If we have a ``machineType`` key,
  135. this is a Compute instance. If we don't have a ``machineType`` key,
  136. but we have ``acceleratorType``, this is a TPU. Otherwise, it's
  137. invalid and an exception is raised.
  138. This works for both node configs and API returned nodes.
  139. """
  140. if (
  141. "machineType" not in node
  142. and "acceleratorType" not in node
  143. and "acceleratorConfig" not in node
  144. ):
  145. raise ValueError(
  146. "Invalid node. For a Compute instance, 'machineType' is required."
  147. "For a TPU instance, 'acceleratorType' OR 'acceleratorConfig' and "
  148. f"no 'machineType' is required. Got {list(node)}."
  149. )
  150. if "machineType" not in node and (
  151. "acceleratorType" in node or "acceleratorConfig" in node
  152. ):
  153. _validate_tpu_config(node)
  154. if not _is_single_host_tpu(node):
  155. # Remove once proper autoscaling support is added.
  156. logger.warning(
  157. "TPU pod detected. Note that while the cluster launcher can create "
  158. "multiple TPU pods, proper autoscaling will not work as expected, "
  159. "as all hosts in a TPU pod need to execute the same program. "
  160. "Proceed with caution."
  161. )
  162. return GCPNodeType.TPU
  163. return GCPNodeType.COMPUTE
  164. def wait_for_crm_operation(operation, crm):
  165. """Poll for cloud resource manager operation until finished."""
  166. logger.info(
  167. "wait_for_crm_operation: "
  168. "Waiting for operation {} to finish...".format(operation)
  169. )
  170. for _ in range(MAX_POLLS):
  171. result = crm.operations().get(name=operation["name"]).execute()
  172. if "error" in result:
  173. raise Exception(result["error"])
  174. if "done" in result and result["done"]:
  175. logger.info("wait_for_crm_operation: Operation done.")
  176. break
  177. time.sleep(POLL_INTERVAL)
  178. return result
  179. def wait_for_compute_global_operation(project_name, operation, compute):
  180. """Poll for global compute operation until finished."""
  181. logger.info(
  182. "wait_for_compute_global_operation: "
  183. "Waiting for operation {} to finish...".format(operation["name"])
  184. )
  185. for _ in range(MAX_POLLS):
  186. result = (
  187. compute.globalOperations()
  188. .get(
  189. project=project_name,
  190. operation=operation["name"],
  191. )
  192. .execute()
  193. )
  194. if "error" in result:
  195. raise Exception(result["error"])
  196. if result["status"] == "DONE":
  197. logger.info("wait_for_compute_global_operation: Operation done.")
  198. break
  199. time.sleep(POLL_INTERVAL)
  200. return result
  201. def _has_tpus_in_node_configs(config: dict) -> bool:
  202. """Check if any nodes in config are TPUs."""
  203. node_configs = [
  204. node_type["node_config"]
  205. for node_type in config["available_node_types"].values()
  206. ]
  207. return any(get_node_type(node) == GCPNodeType.TPU for node in node_configs)
  208. def _is_head_node_a_tpu(config: dict) -> bool:
  209. """Check if the head node is a TPU."""
  210. node_configs = {
  211. node_id: node_type["node_config"]
  212. for node_id, node_type in config["available_node_types"].items()
  213. }
  214. return get_node_type(node_configs[config["head_node_type"]]) == GCPNodeType.TPU
  215. def build_request(http, *args, **kwargs):
  216. new_http = google_auth_httplib2.AuthorizedHttp(
  217. http.credentials, http=httplib2.Http()
  218. )
  219. return googleapiclient.http.HttpRequest(new_http, *args, **kwargs)
  220. def _create_crm(gcp_credentials=None):
  221. return discovery.build(
  222. "cloudresourcemanager",
  223. "v1",
  224. credentials=gcp_credentials,
  225. requestBuilder=build_request,
  226. cache_discovery=False,
  227. )
  228. def _create_iam(gcp_credentials=None):
  229. return discovery.build(
  230. "iam",
  231. "v1",
  232. credentials=gcp_credentials,
  233. requestBuilder=build_request,
  234. cache_discovery=False,
  235. )
  236. def _create_compute(gcp_credentials=None):
  237. return discovery.build(
  238. "compute",
  239. "v1",
  240. credentials=gcp_credentials,
  241. requestBuilder=build_request,
  242. cache_discovery=False,
  243. )
  244. def _create_tpu(gcp_credentials=None):
  245. return discovery.build(
  246. "tpu",
  247. TPU_VERSION,
  248. credentials=gcp_credentials,
  249. requestBuilder=build_request,
  250. cache_discovery=False,
  251. discoveryServiceUrl="https://tpu.googleapis.com/$discovery/rest",
  252. )
  253. def construct_clients_from_provider_config(provider_config):
  254. """
  255. Attempt to fetch and parse the JSON GCP credentials from the provider
  256. config yaml file.
  257. tpu resource (the last element of the tuple) will be None if
  258. `_has_tpus` in provider config is not set or False.
  259. """
  260. gcp_credentials = provider_config.get("gcp_credentials")
  261. if gcp_credentials is None:
  262. logger.debug(
  263. "gcp_credentials not found in cluster yaml file. "
  264. "Falling back to GOOGLE_APPLICATION_CREDENTIALS "
  265. "environment variable."
  266. )
  267. tpu_resource = (
  268. _create_tpu()
  269. if provider_config.get(HAS_TPU_PROVIDER_FIELD, False)
  270. else None
  271. )
  272. # If gcp_credentials is None, then discovery.build will search for
  273. # credentials in the local environment.
  274. return _create_crm(), _create_iam(), _create_compute(), tpu_resource
  275. assert (
  276. "type" in gcp_credentials
  277. ), "gcp_credentials cluster yaml field missing 'type' field."
  278. assert (
  279. "credentials" in gcp_credentials
  280. ), "gcp_credentials cluster yaml field missing 'credentials' field."
  281. cred_type = gcp_credentials["type"]
  282. credentials_field = gcp_credentials["credentials"]
  283. if cred_type == "service_account":
  284. # If parsing the gcp_credentials failed, then the user likely made a
  285. # mistake in copying the credentials into the config yaml.
  286. try:
  287. service_account_info = json.loads(credentials_field)
  288. except json.decoder.JSONDecodeError:
  289. raise RuntimeError(
  290. "gcp_credentials found in cluster yaml file but "
  291. "formatted improperly."
  292. )
  293. credentials = service_account.Credentials.from_service_account_info(
  294. service_account_info
  295. )
  296. elif cred_type == "credentials_token":
  297. # Otherwise the credentials type must be credentials_token.
  298. credentials = OAuthCredentials(credentials_field)
  299. tpu_resource = (
  300. _create_tpu(credentials)
  301. if provider_config.get(HAS_TPU_PROVIDER_FIELD, False)
  302. else None
  303. )
  304. return (
  305. _create_crm(credentials),
  306. _create_iam(credentials),
  307. _create_compute(credentials),
  308. tpu_resource,
  309. )
  310. def bootstrap_gcp(config):
  311. config = copy.deepcopy(config)
  312. check_legacy_fields(config)
  313. # Used internally to store head IAM role.
  314. config["head_node"] = {}
  315. # Check if we have any TPUs defined, and if so,
  316. # insert that information into the provider config
  317. if _has_tpus_in_node_configs(config):
  318. config["provider"][HAS_TPU_PROVIDER_FIELD] = True
  319. crm, iam, compute, tpu = construct_clients_from_provider_config(config["provider"])
  320. config = _configure_project(config, crm)
  321. config = _configure_iam_role(config, crm, iam)
  322. config = _configure_key_pair(config, compute)
  323. config = _configure_subnet(config, compute)
  324. return config
  325. def _configure_project(config, crm):
  326. """Setup a Google Cloud Platform Project.
  327. Google Compute Platform organizes all the resources, such as storage
  328. buckets, users, and instances under projects. This is different from
  329. aws ec2 where everything is global.
  330. """
  331. config = copy.deepcopy(config)
  332. project_id = config["provider"].get("project_id")
  333. assert config["provider"]["project_id"] is not None, (
  334. "'project_id' must be set in the 'provider' section of the autoscaler"
  335. " config. Notice that the project id must be globally unique."
  336. )
  337. project = _get_project(project_id, crm)
  338. if project is None:
  339. # Project not found, try creating it
  340. _create_project(project_id, crm)
  341. project = _get_project(project_id, crm)
  342. assert project is not None, "Failed to create project"
  343. assert (
  344. project["lifecycleState"] == "ACTIVE"
  345. ), "Project status needs to be ACTIVE, got {}".format(project["lifecycleState"])
  346. config["provider"]["project_id"] = project["projectId"]
  347. return config
  348. def _configure_iam_role(config, crm, iam):
  349. """Setup a gcp service account with IAM roles.
  350. Creates a gcp service acconut and binds IAM roles which allow it to control
  351. control storage/compute services. Specifically, the head node needs to have
  352. an IAM role that allows it to create further gce instances and store items
  353. in google cloud storage.
  354. TODO: Allow the name/id of the service account to be configured
  355. """
  356. config = copy.deepcopy(config)
  357. email = SERVICE_ACCOUNT_EMAIL_TEMPLATE.format(
  358. account_id=DEFAULT_SERVICE_ACCOUNT_ID,
  359. project_id=config["provider"]["project_id"],
  360. )
  361. service_account = _get_service_account(email, config, iam)
  362. if service_account is None:
  363. logger.info(
  364. "_configure_iam_role: "
  365. "Creating new service account {}".format(DEFAULT_SERVICE_ACCOUNT_ID)
  366. )
  367. service_account = _create_service_account(
  368. DEFAULT_SERVICE_ACCOUNT_ID, DEFAULT_SERVICE_ACCOUNT_CONFIG, config, iam
  369. )
  370. assert service_account is not None, "Failed to create service account"
  371. if config["provider"].get(HAS_TPU_PROVIDER_FIELD, False):
  372. roles = DEFAULT_SERVICE_ACCOUNT_ROLES + TPU_SERVICE_ACCOUNT_ROLES
  373. else:
  374. roles = DEFAULT_SERVICE_ACCOUNT_ROLES
  375. _add_iam_policy_binding(service_account, roles, crm)
  376. config["head_node"]["serviceAccounts"] = [
  377. {
  378. "email": service_account["email"],
  379. # NOTE: The amount of access is determined by the scope + IAM
  380. # role of the service account. Even if the cloud-platform scope
  381. # gives (scope) access to the whole cloud-platform, the service
  382. # account is limited by the IAM rights specified below.
  383. "scopes": ["https://www.googleapis.com/auth/cloud-platform"],
  384. }
  385. ]
  386. return config
  387. def _configure_key_pair(config, compute):
  388. """Configure SSH access, using an existing key pair if possible.
  389. Creates a project-wide ssh key that can be used to access all the instances
  390. unless explicitly prohibited by instance config.
  391. The ssh-keys created by ray are of format:
  392. [USERNAME]:ssh-rsa [KEY_VALUE] [USERNAME]
  393. where:
  394. [USERNAME] is the user for the SSH key, specified in the config.
  395. [KEY_VALUE] is the public SSH key value.
  396. """
  397. config = copy.deepcopy(config)
  398. if "ssh_private_key" in config["auth"]:
  399. return config
  400. ssh_user = config["auth"]["ssh_user"]
  401. project = compute.projects().get(project=config["provider"]["project_id"]).execute()
  402. # Key pairs associated with project meta data. The key pairs are general,
  403. # and not just ssh keys.
  404. ssh_keys_str = next(
  405. (
  406. item
  407. for item in project["commonInstanceMetadata"].get("items", [])
  408. if item["key"] == "ssh-keys"
  409. ),
  410. {},
  411. ).get("value", "")
  412. ssh_keys = ssh_keys_str.split("\n") if ssh_keys_str else []
  413. # Try a few times to get or create a good key pair.
  414. key_found = False
  415. for i in range(10):
  416. key_name = generate_ssh_key_name(
  417. "gcp",
  418. i,
  419. config["provider"]["region"],
  420. config["provider"]["project_id"],
  421. ssh_user,
  422. )
  423. public_key_path, private_key_path = generate_ssh_key_paths(key_name)
  424. for ssh_key in ssh_keys:
  425. key_parts = ssh_key.split(" ")
  426. if len(key_parts) != 3:
  427. continue
  428. if key_parts[2] == ssh_user and os.path.exists(private_key_path):
  429. # Found a key
  430. key_found = True
  431. break
  432. # Writing the new ssh key to the filesystem fails if the ~/.ssh
  433. # directory doesn't already exist.
  434. os.makedirs(os.path.expanduser("~/.ssh"), exist_ok=True)
  435. # Create a key since it doesn't exist locally or in GCP
  436. if not key_found and not os.path.exists(private_key_path):
  437. logger.info(
  438. "_configure_key_pair: Creating new key pair {}".format(key_name)
  439. )
  440. public_key, private_key = generate_rsa_key_pair()
  441. for attempt in range(MAX_POLLS):
  442. try:
  443. _create_project_ssh_key_pair(project, public_key, ssh_user, compute)
  444. break
  445. except errors.HttpError as e:
  446. if e.resp.status != 412 or attempt == MAX_POLLS - 1:
  447. raise
  448. logger.warning(
  449. "GCP project metadata update conflict for %s (%s); retrying",
  450. config["provider"]["project_id"],
  451. e,
  452. )
  453. time.sleep(POLL_INTERVAL)
  454. project = (
  455. compute.projects()
  456. .get(project=config["provider"]["project_id"])
  457. .execute()
  458. )
  459. # Create the directory if it doesn't exists
  460. private_key_dir = os.path.dirname(private_key_path)
  461. os.makedirs(private_key_dir, exist_ok=True)
  462. # We need to make sure to _create_ the file with the right
  463. # permissions. In order to do that we need to change the default
  464. # os.open behavior to include the mode we want.
  465. with open(
  466. private_key_path,
  467. "w",
  468. opener=partial(os.open, mode=0o600),
  469. ) as f:
  470. f.write(private_key)
  471. with open(public_key_path, "w") as f:
  472. f.write(public_key)
  473. key_found = True
  474. break
  475. if key_found:
  476. break
  477. assert key_found, "SSH keypair for user {} not found for {}".format(
  478. ssh_user, private_key_path
  479. )
  480. assert os.path.exists(
  481. private_key_path
  482. ), "Private key file {} not found for user {}".format(private_key_path, ssh_user)
  483. logger.info(
  484. "_configure_key_pair: "
  485. "Private key not specified in config, using"
  486. "{}".format(private_key_path)
  487. )
  488. config["auth"]["ssh_private_key"] = private_key_path
  489. return config
  490. def _configure_subnet(config, compute):
  491. """Pick a reasonable subnet if not specified by the config."""
  492. config = copy.deepcopy(config)
  493. node_configs = [
  494. node_type["node_config"]
  495. for node_type in config["available_node_types"].values()
  496. ]
  497. # Rationale: avoid subnet lookup if the network is already
  498. # completely manually configured
  499. # networkInterfaces is compute, networkConfig is TPU
  500. if all(
  501. "networkInterfaces" in node_config or "networkConfig" in node_config
  502. for node_config in node_configs
  503. ):
  504. return config
  505. subnets = _list_subnets(config, compute)
  506. if not subnets:
  507. raise NotImplementedError("Should be able to create subnet.")
  508. # TODO: make sure that we have usable subnet. Maybe call
  509. # compute.subnetworks().listUsable? For some reason it didn't
  510. # work out-of-the-box
  511. default_subnet = subnets[0]
  512. default_interfaces = [
  513. {
  514. "subnetwork": default_subnet["selfLink"],
  515. "accessConfigs": [
  516. {
  517. "name": "External NAT",
  518. "type": "ONE_TO_ONE_NAT",
  519. }
  520. ],
  521. }
  522. ]
  523. for node_config in node_configs:
  524. # The not applicable key will be removed during node creation
  525. # compute
  526. if "networkInterfaces" not in node_config:
  527. node_config["networkInterfaces"] = copy.deepcopy(default_interfaces)
  528. # TPU
  529. if "networkConfig" not in node_config:
  530. node_config["networkConfig"] = copy.deepcopy(default_interfaces)[0]
  531. node_config["networkConfig"].pop("accessConfigs")
  532. return config
  533. def _list_subnets(config, compute):
  534. response = (
  535. compute.subnetworks()
  536. .list(
  537. project=config["provider"]["project_id"],
  538. region=config["provider"]["region"],
  539. )
  540. .execute()
  541. )
  542. return response["items"]
  543. def _get_subnet(config, subnet_id, compute):
  544. subnet = (
  545. compute.subnetworks()
  546. .get(
  547. project=config["provider"]["project_id"],
  548. region=config["provider"]["region"],
  549. subnetwork=subnet_id,
  550. )
  551. .execute()
  552. )
  553. return subnet
  554. def _get_project(project_id, crm):
  555. try:
  556. project = crm.projects().get(projectId=project_id).execute()
  557. except errors.HttpError as e:
  558. if e.resp.status != 403:
  559. raise
  560. project = None
  561. return project
  562. def _create_project(project_id, crm):
  563. operation = (
  564. crm.projects()
  565. .create(body={"projectId": project_id, "name": project_id})
  566. .execute()
  567. )
  568. result = wait_for_crm_operation(operation, crm)
  569. return result
  570. def _get_service_account(account, config, iam):
  571. project_id = config["provider"]["project_id"]
  572. full_name = "projects/{project_id}/serviceAccounts/{account}".format(
  573. project_id=project_id, account=account
  574. )
  575. try:
  576. service_account = iam.projects().serviceAccounts().get(name=full_name).execute()
  577. except errors.HttpError as e:
  578. if e.resp.status != 404:
  579. raise
  580. service_account = None
  581. return service_account
  582. def _create_service_account(account_id, account_config, config, iam):
  583. project_id = config["provider"]["project_id"]
  584. service_account = (
  585. iam.projects()
  586. .serviceAccounts()
  587. .create(
  588. name="projects/{project_id}".format(project_id=project_id),
  589. body={
  590. "accountId": account_id,
  591. "serviceAccount": account_config,
  592. },
  593. )
  594. .execute()
  595. )
  596. return service_account
  597. def _add_iam_policy_binding(service_account, roles, crm):
  598. """Add new IAM roles for the service account."""
  599. project_id = service_account["projectId"]
  600. email = service_account["email"]
  601. member_id = "serviceAccount:" + email
  602. policy = (
  603. crm.projects()
  604. .getIamPolicy(
  605. resource=project_id, body={"options": {"requestedPolicyVersion": 3}}
  606. )
  607. .execute()
  608. )
  609. already_configured = True
  610. for role in roles:
  611. role_exists = False
  612. for binding in policy["bindings"]:
  613. if binding["role"] == role:
  614. if member_id not in binding["members"]:
  615. binding["members"].append(member_id)
  616. already_configured = False
  617. role_exists = True
  618. if not role_exists:
  619. already_configured = False
  620. policy["bindings"].append(
  621. {
  622. "members": [member_id],
  623. "role": role,
  624. }
  625. )
  626. if already_configured:
  627. # In some managed environments, an admin needs to grant the
  628. # roles, so only call setIamPolicy if needed.
  629. return
  630. result = (
  631. crm.projects()
  632. .setIamPolicy(
  633. resource=project_id,
  634. body={
  635. "policy": policy,
  636. },
  637. )
  638. .execute()
  639. )
  640. return result
  641. def _create_project_ssh_key_pair(project, public_key, ssh_user, compute):
  642. """Inserts an ssh-key into project commonInstanceMetadata"""
  643. key_parts = public_key.split(" ")
  644. # Sanity checks to make sure that the generated key matches expectation
  645. assert len(key_parts) == 2, key_parts
  646. assert key_parts[0] == "ssh-rsa", key_parts
  647. new_ssh_meta = "{ssh_user}:ssh-rsa {key_value} {ssh_user}".format(
  648. ssh_user=ssh_user, key_value=key_parts[1]
  649. )
  650. common_instance_metadata = project["commonInstanceMetadata"]
  651. items = common_instance_metadata.get("items", [])
  652. ssh_keys_i = next(
  653. (i for i, item in enumerate(items) if item["key"] == "ssh-keys"), None
  654. )
  655. if ssh_keys_i is None:
  656. items.append({"key": "ssh-keys", "value": new_ssh_meta})
  657. else:
  658. ssh_keys = items[ssh_keys_i]
  659. ssh_keys["value"] += "\n" + new_ssh_meta
  660. items[ssh_keys_i] = ssh_keys
  661. common_instance_metadata["items"] = items
  662. try:
  663. operation = (
  664. compute.projects()
  665. .setCommonInstanceMetadata(
  666. project=project["name"], body=common_instance_metadata
  667. )
  668. .execute()
  669. )
  670. response = wait_for_compute_global_operation(
  671. project["name"], operation, compute
  672. )
  673. except errors.HttpError as e: # Only trim when explicitly opted in.
  674. if (
  675. e.resp.status != 413
  676. or ssh_keys_i is None
  677. or os.environ.get("RAY_TRIM_AUTOSCALER_SSH_KEYS") != "1"
  678. ):
  679. raise
  680. logger.warning(
  681. "GCP project metadata size limit reached for %s (%s); "
  682. "removing half of ssh keys and retrying",
  683. project["name"],
  684. e,
  685. )
  686. ssh_keys_value = items[ssh_keys_i].get("value", "")
  687. ssh_keys = [key for key in ssh_keys_value.splitlines() if key.strip()]
  688. trim_start = ( # If our new key is the only one, this preserves it.
  689. len(ssh_keys) // 2
  690. )
  691. items[ssh_keys_i]["value"] = "\n".join(ssh_keys[trim_start:])
  692. common_instance_metadata["items"] = items
  693. operation = (
  694. compute.projects()
  695. .setCommonInstanceMetadata(
  696. project=project["name"], body=common_instance_metadata
  697. )
  698. .execute()
  699. )
  700. response = wait_for_compute_global_operation(
  701. project["name"], operation, compute
  702. )
  703. return response