# config.py -- AWS provider bootstrap configuration for the Ray autoscaler.
  1. import copy
  2. import itertools
  3. import json
  4. import logging
  5. import os
  6. import time
  7. from collections import Counter
  8. from functools import lru_cache, partial
  9. from typing import Any, Dict, List, Optional, Set, Tuple
  10. import boto3
  11. import botocore
  12. from packaging.version import Version
  13. from ray.autoscaler._private.aws.cloudwatch.cloudwatch_helper import (
  14. CloudwatchHelper as cwh,
  15. )
  16. from ray.autoscaler._private.aws.utils import (
  17. LazyDefaultDict,
  18. handle_boto_error,
  19. resource_cache,
  20. )
  21. from ray.autoscaler._private.cli_logger import cf, cli_logger
  22. from ray.autoscaler._private.event_system import CreateClusterEvent, global_event_system
  23. from ray.autoscaler._private.providers import _PROVIDER_PRETTY_NAMES
  24. from ray.autoscaler._private.util import check_legacy_fields
  25. from ray.autoscaler.tags import NODE_TYPE_LEGACY_HEAD, NODE_TYPE_LEGACY_WORKER
  26. logger = logging.getLogger(__name__)
  27. RAY = "ray-autoscaler"
  28. DEFAULT_RAY_INSTANCE_PROFILE = RAY + "-v1"
  29. DEFAULT_RAY_IAM_ROLE = RAY + "-v1"
  30. SECURITY_GROUP_TEMPLATE = RAY + "-{}"
  31. # V71.0 has CUDA 11.2
  32. DEFAULT_AMI_NAME = "AWS Deep Learning AMI (Ubuntu 18.04) V71.0"
  33. # Obtained with:
  34. # for region in $(aws ec2 describe-regions --query "Regions[].RegionName" --output text); do
  35. # echo "Region: $region";
  36. # aws ec2 describe-images \
  37. # --region $region \
  38. # --filters "Name=name,Values=Deep Learning AMI (Ubuntu 18.04) Version 71.0*" \
  39. # --query "Images[].[ImageId, Name]" \
  40. # --output text;
  41. # echo "--------------------------";
  42. # done
  43. # TODO(alex) : write a unit test to make sure we update AMI version used in
  44. # ray/autoscaler/aws/example-full.yaml whenever we update this dict.
  45. DEFAULT_AMI = {
  46. "us-east-1": "ami-0271ce88f6c03e149", # US East (N. Virginia)
  47. "us-east-2": "ami-0ce21b0ce8e0f5a37", # US East (Ohio)
  48. "us-west-1": "ami-030896954b2b72361", # US West (N. California)
  49. "us-west-2": "ami-0a2b85e15b7c0ac34", # US West (Oregon)
  50. "ca-central-1": "ami-044b98bea12677052", # Canada (Central)
  51. "eu-central-1": "ami-004d285ecf53fb335", # EU (Frankfurt)
  52. "eu-west-1": "ami-0e17ed40febe794a6", # EU (Ireland)
  53. "eu-west-2": "ami-0ad9e6b9d4e22df1a", # EU (London)
  54. "eu-west-3": "ami-05698775025146698", # EU (Paris)
  55. "sa-east-1": "ami-004b19c7ed8d790bc", # SA (Sao Paulo)
  56. "ap-northeast-1": "ami-03e45b1f18378ea97", # Asia Pacific (Tokyo)
  57. "ap-northeast-2": "ami-095bae3e623e2ff99", # Asia Pacific (Seoul)
  58. "ap-northeast-3": "ami-02fd0ba69c1ef37ad", # Asia Pacific (Osaka)
  59. "ap-southeast-1": "ami-03ef1223872072e95", # Asia Pacific (Singapore)
  60. "ap-southeast-2": "ami-0bcf4f75e44e0a866", # Asia Pacific (Sydney)
  61. }
  62. # todo: cli_logger should handle this assert properly
  63. # this should probably also happens somewhere else
  64. assert Version(boto3.__version__) >= Version(
  65. "1.4.8"
  66. ), "Boto3 version >= 1.4.8 required, try `pip install -U boto3`"
  67. def key_pair(i, region, key_name):
  68. """
  69. If key_name is not None, key_pair will be named after key_name.
  70. Returns the ith default (aws_key_pair_name, key_pair_path).
  71. """
  72. if i == 0:
  73. key_pair_name = "{}_{}".format(RAY, region) if key_name is None else key_name
  74. return (
  75. key_pair_name,
  76. os.path.expanduser("~/.ssh/{}.pem".format(key_pair_name)),
  77. )
  78. key_pair_name = (
  79. "{}_{}_{}".format(RAY, i, region)
  80. if key_name is None
  81. else key_name + "_key-{}".format(i)
  82. )
  83. return (key_pair_name, os.path.expanduser("~/.ssh/{}.pem".format(key_pair_name)))
  84. # Suppress excessive connection dropped logs from boto
  85. logging.getLogger("botocore").setLevel(logging.WARNING)
  86. _log_info = {}
  87. def reload_log_state(override_log_info):
  88. _log_info.update(override_log_info)
  89. def get_log_state():
  90. return _log_info.copy()
  91. def _set_config_info(**kwargs):
  92. """Record configuration artifacts useful for logging."""
  93. # todo: this is technically fragile iff we ever use multiple configs
  94. for k, v in kwargs.items():
  95. _log_info[k] = v
  96. def _arn_to_name(arn):
  97. return arn.split(":")[-1].split("/")[-1]
  98. def log_to_cli(config: Dict[str, Any]) -> None:
  99. provider_name = _PROVIDER_PRETTY_NAMES.get("aws", None)
  100. cli_logger.doassert(
  101. provider_name is not None, "Could not find a pretty name for the AWS provider."
  102. )
  103. head_node_type = config["head_node_type"]
  104. head_node_config = config["available_node_types"][head_node_type]["node_config"]
  105. with cli_logger.group("{} config", provider_name):
  106. def print_info(
  107. resource_string: str,
  108. key: str,
  109. src_key: str,
  110. allowed_tags: Optional[List[str]] = None,
  111. list_value: bool = False,
  112. ) -> None:
  113. if allowed_tags is None:
  114. allowed_tags = ["default"]
  115. node_tags = {}
  116. # set of configurations corresponding to `key`
  117. unique_settings = set()
  118. for node_type_key, node_type in config["available_node_types"].items():
  119. node_tags[node_type_key] = {}
  120. tag = _log_info[src_key][node_type_key]
  121. if tag in allowed_tags:
  122. node_tags[node_type_key][tag] = True
  123. setting = node_type["node_config"].get(key)
  124. if list_value:
  125. unique_settings.add(tuple(setting))
  126. else:
  127. unique_settings.add(setting)
  128. head_value_str = head_node_config[key]
  129. if list_value:
  130. head_value_str = cli_logger.render_list(head_value_str)
  131. if len(unique_settings) == 1:
  132. # all node types are configured the same, condense
  133. # log output
  134. cli_logger.labeled_value(
  135. resource_string + " (all available node types)",
  136. "{}",
  137. head_value_str,
  138. _tags=node_tags[config["head_node_type"]],
  139. )
  140. else:
  141. # do head node type first
  142. cli_logger.labeled_value(
  143. resource_string + f" ({head_node_type})",
  144. "{}",
  145. head_value_str,
  146. _tags=node_tags[head_node_type],
  147. )
  148. # go through remaining types
  149. for node_type_key, node_type in config["available_node_types"].items():
  150. if node_type_key == head_node_type:
  151. continue
  152. workers_value_str = node_type["node_config"][key]
  153. if list_value:
  154. workers_value_str = cli_logger.render_list(workers_value_str)
  155. cli_logger.labeled_value(
  156. resource_string + f" ({node_type_key})",
  157. "{}",
  158. workers_value_str,
  159. _tags=node_tags[node_type_key],
  160. )
  161. tags = {"default": _log_info["head_instance_profile_src"] == "default"}
  162. # head_node_config is the head_node_type's config,
  163. # config["head_node"] is a field that gets applied only to the actual
  164. # head node (and not workers of the head's node_type)
  165. assert (
  166. "IamInstanceProfile" in head_node_config
  167. or "IamInstanceProfile" in config["head_node"]
  168. )
  169. if "IamInstanceProfile" in head_node_config:
  170. # If the user manually configured the role we're here.
  171. IamProfile = head_node_config["IamInstanceProfile"]
  172. elif "IamInstanceProfile" in config["head_node"]:
  173. # If we filled the default IAM role, we're here.
  174. IamProfile = config["head_node"]["IamInstanceProfile"]
  175. profile_arn = IamProfile.get("Arn")
  176. profile_name = _arn_to_name(profile_arn) if profile_arn else IamProfile["Name"]
  177. cli_logger.labeled_value("IAM Profile", "{}", profile_name, _tags=tags)
  178. if all(
  179. "KeyName" in node_type["node_config"]
  180. for node_type in config["available_node_types"].values()
  181. ):
  182. print_info("EC2 Key pair", "KeyName", "keypair_src")
  183. print_info("VPC Subnets", "SubnetIds", "subnet_src", list_value=True)
  184. print_info(
  185. "EC2 Security groups",
  186. "SecurityGroupIds",
  187. "security_group_src",
  188. list_value=True,
  189. )
  190. print_info("EC2 AMI", "ImageId", "ami_src", allowed_tags=["dlami"])
  191. cli_logger.newline()
  192. def bootstrap_aws(config):
  193. # create a copy of the input config to modify
  194. config = copy.deepcopy(config)
  195. # Log warnings if user included deprecated `head_node` or `worker_nodes`
  196. # fields. Raise error if no `available_node_types`
  197. check_legacy_fields(config)
  198. # Used internally to store head IAM role.
  199. config["head_node"] = {}
  200. # If a LaunchTemplate is provided, extract the necessary fields for the
  201. # config stages below.
  202. config = _configure_from_launch_template(config)
  203. # If NetworkInterfaces are provided, extract the necessary fields for the
  204. # config stages below.
  205. config = _configure_from_network_interfaces(config)
  206. # The head node needs to have an IAM role that allows it to create further
  207. # EC2 instances.
  208. config = _configure_iam_role(config)
  209. # Configure SSH access, using an existing key pair if possible.
  210. config = _configure_key_pair(config)
  211. global_event_system.execute_callback(
  212. CreateClusterEvent.ssh_keypair_downloaded,
  213. {"ssh_key_path": config["auth"]["ssh_private_key"]},
  214. )
  215. # Pick a reasonable subnet if not specified by the user.
  216. config = _configure_subnet(config)
  217. # Cluster workers should be in a security group that permits traffic within
  218. # the group, and also SSH access from outside.
  219. config = _configure_security_group(config)
  220. # Provide a helpful message for missing AMI.
  221. _check_ami(config)
  222. return config
  223. def _configure_iam_role(config):
  224. head_node_type = config["head_node_type"]
  225. head_node_config = config["available_node_types"][head_node_type]["node_config"]
  226. if "IamInstanceProfile" in head_node_config:
  227. _set_config_info(head_instance_profile_src="config")
  228. return config
  229. _set_config_info(head_instance_profile_src="default")
  230. instance_profile_name = cwh.resolve_instance_profile_name(
  231. config["provider"],
  232. DEFAULT_RAY_INSTANCE_PROFILE,
  233. )
  234. profile = _get_instance_profile(instance_profile_name, config)
  235. if profile is None:
  236. cli_logger.verbose(
  237. "Creating new IAM instance profile {} for use as the default.",
  238. cf.bold(instance_profile_name),
  239. )
  240. client = _client("iam", config)
  241. client.create_instance_profile(InstanceProfileName=instance_profile_name)
  242. profile = _get_instance_profile(instance_profile_name, config)
  243. time.sleep(15) # wait for propagation
  244. cli_logger.doassert(
  245. profile is not None, "Failed to create instance profile."
  246. ) # todo: err msg
  247. assert profile is not None, "Failed to create instance profile"
  248. if not profile.roles:
  249. role_name = cwh.resolve_iam_role_name(config["provider"], DEFAULT_RAY_IAM_ROLE)
  250. role = _get_role(role_name, config)
  251. if role is None:
  252. cli_logger.verbose(
  253. "Creating new IAM role {} for use as the default instance role.",
  254. cf.bold(role_name),
  255. )
  256. iam = _resource("iam", config)
  257. policy_doc = {
  258. "Statement": [
  259. {
  260. "Effect": "Allow",
  261. "Principal": {"Service": "ec2.amazonaws.com"},
  262. "Action": "sts:AssumeRole",
  263. },
  264. ]
  265. }
  266. attach_policy_arns = cwh.resolve_policy_arns(
  267. config["provider"],
  268. iam,
  269. [
  270. "arn:aws:iam::aws:policy/AmazonEC2FullAccess",
  271. "arn:aws:iam::aws:policy/AmazonS3FullAccess",
  272. ],
  273. )
  274. iam.create_role(
  275. RoleName=role_name, AssumeRolePolicyDocument=json.dumps(policy_doc)
  276. )
  277. role = _get_role(role_name, config)
  278. cli_logger.doassert(
  279. role is not None, "Failed to create role."
  280. ) # todo: err msg
  281. assert role is not None, "Failed to create role"
  282. for policy_arn in attach_policy_arns:
  283. role.attach_policy(PolicyArn=policy_arn)
  284. profile.add_role(RoleName=role.name)
  285. time.sleep(15) # wait for propagation
  286. # Add IAM role to "head_node" field so that it is applied only to
  287. # the head node -- not to workers with the same node type as the head.
  288. config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn}
  289. return config
  290. def _configure_key_pair(config):
  291. node_types = config["available_node_types"]
  292. # map from node type key -> source of KeyName field
  293. key_pair_src_info = {}
  294. _set_config_info(keypair_src=key_pair_src_info)
  295. if "ssh_private_key" in config["auth"]:
  296. for node_type_key in node_types:
  297. # keypairs should be provided in the config
  298. key_pair_src_info[node_type_key] = "config"
  299. # If the key is not configured via the cloudinit
  300. # UserData, it should be configured via KeyName or
  301. # else we will risk starting a node that we cannot
  302. # SSH into:
  303. for node_type in node_types:
  304. node_config = node_types[node_type]["node_config"]
  305. if "UserData" not in node_config:
  306. cli_logger.doassert(
  307. "KeyName" in node_config, _key_assert_msg(node_type)
  308. )
  309. assert "KeyName" in node_config
  310. return config
  311. for node_type_key in node_types:
  312. key_pair_src_info[node_type_key] = "default"
  313. ec2 = _resource("ec2", config)
  314. # Writing the new ssh key to the filesystem fails if the ~/.ssh
  315. # directory doesn't already exist.
  316. os.makedirs(os.path.expanduser("~/.ssh"), exist_ok=True)
  317. # Try a few times to get or create a good key pair.
  318. MAX_NUM_KEYS = 600
  319. for i in range(MAX_NUM_KEYS):
  320. key_name = config["provider"].get("key_pair", {}).get("key_name")
  321. key_name, key_path = key_pair(i, config["provider"]["region"], key_name)
  322. key = _get_key(key_name, config)
  323. # Found a good key.
  324. if key and os.path.exists(key_path):
  325. break
  326. # We can safely create a new key.
  327. if not key and not os.path.exists(key_path):
  328. cli_logger.verbose(
  329. "Creating new key pair {} for use as the default.", cf.bold(key_name)
  330. )
  331. key = ec2.create_key_pair(KeyName=key_name)
  332. # We need to make sure to _create_ the file with the right
  333. # permissions. In order to do that we need to change the default
  334. # os.open behavior to include the mode we want.
  335. with open(key_path, "w", opener=partial(os.open, mode=0o600)) as f:
  336. f.write(key.key_material)
  337. break
  338. if not key:
  339. cli_logger.abort(
  340. "No matching local key file for any of the key pairs in this "
  341. "account with ids from 0..{}. "
  342. "Consider deleting some unused keys pairs from your account.",
  343. key_name,
  344. )
  345. cli_logger.doassert(
  346. os.path.exists(key_path),
  347. "Private key file " + cf.bold("{}") + " not found for " + cf.bold("{}"),
  348. key_path,
  349. key_name,
  350. ) # todo: err msg
  351. assert os.path.exists(key_path), "Private key file {} not found for {}".format(
  352. key_path, key_name
  353. )
  354. config["auth"]["ssh_private_key"] = key_path
  355. for node_type in node_types.values():
  356. node_config = node_type["node_config"]
  357. node_config["KeyName"] = key_name
  358. return config
  359. def _key_assert_msg(node_type: str) -> str:
  360. if node_type == NODE_TYPE_LEGACY_WORKER:
  361. return "`KeyName` missing for worker nodes."
  362. elif node_type == NODE_TYPE_LEGACY_HEAD:
  363. return "`KeyName` missing for head node."
  364. else:
  365. return (
  366. "`KeyName` missing from the `node_config` of" f" node type `{node_type}`."
  367. )
  368. def _usable_subnet_ids(
  369. user_specified_subnets: Optional[List[Any]],
  370. all_subnets: List[Any],
  371. azs: Optional[str],
  372. vpc_id_of_sg: Optional[str],
  373. use_internal_ips: bool,
  374. node_type_key: str,
  375. ) -> Tuple[List[str], str]:
  376. """Prunes subnets down to those that meet the following criteria.
  377. Subnets must be:
  378. * 'Available' according to AWS.
  379. * Public, unless `use_internal_ips` is specified.
  380. * In one of the AZs, if AZs are provided.
  381. * In the given VPC, if a VPC is specified for Security Groups.
  382. Returns:
  383. List[str]: Subnets that are usable.
  384. str: VPC ID of the first subnet.
  385. """
  386. def _are_user_subnets_pruned(current_subnets: List[Any]) -> bool:
  387. return user_specified_subnets is not None and len(current_subnets) != len(
  388. user_specified_subnets
  389. )
  390. def _get_pruned_subnets(current_subnets: List[Any]) -> Set[str]:
  391. current_subnet_ids = {s.subnet_id for s in current_subnets}
  392. user_specified_subnet_ids = {s.subnet_id for s in user_specified_subnets}
  393. return user_specified_subnet_ids - current_subnet_ids
  394. try:
  395. candidate_subnets = (
  396. user_specified_subnets
  397. if user_specified_subnets is not None
  398. else all_subnets
  399. )
  400. if vpc_id_of_sg:
  401. candidate_subnets = [
  402. s for s in candidate_subnets if s.vpc_id == vpc_id_of_sg
  403. ]
  404. subnets = sorted(
  405. (
  406. s
  407. for s in candidate_subnets
  408. if s.state == "available"
  409. and (use_internal_ips or s.map_public_ip_on_launch)
  410. ),
  411. reverse=True, # sort from Z-A
  412. key=lambda subnet: subnet.availability_zone,
  413. )
  414. except botocore.exceptions.ClientError as exc:
  415. handle_boto_error(exc, "Failed to fetch available subnets from AWS.")
  416. raise exc
  417. if not subnets:
  418. cli_logger.abort(
  419. f"No usable subnets found for node type {node_type_key}, try "
  420. "manually creating an instance in your specified region to "
  421. "populate the list of subnets and trying this again.\n"
  422. "Note that the subnet must map public IPs "
  423. "on instance launch unless you set `use_internal_ips: true` in "
  424. "the `provider` config."
  425. )
  426. elif _are_user_subnets_pruned(subnets):
  427. cli_logger.abort(
  428. f"The specified subnets for node type {node_type_key} are not "
  429. f"usable: {_get_pruned_subnets(subnets)}"
  430. )
  431. if azs is not None:
  432. azs = [az.strip() for az in azs.split(",")]
  433. subnets = [
  434. s
  435. for az in azs # Iterate over AZs first to maintain the ordering
  436. for s in subnets
  437. if s.availability_zone == az
  438. ]
  439. if not subnets:
  440. cli_logger.abort(
  441. f"No usable subnets matching availability zone {azs} found "
  442. f"for node type {node_type_key}.\nChoose a different "
  443. "availability zone or try manually creating an instance in "
  444. "your specified region to populate the list of subnets and "
  445. "trying this again."
  446. )
  447. elif _are_user_subnets_pruned(subnets):
  448. cli_logger.abort(
  449. f"MISMATCH between specified subnets and Availability Zones! "
  450. "The following Availability Zones were specified in the "
  451. f"`provider section`: {azs}.\n The following subnets for node "
  452. f"type `{node_type_key}` have no matching availability zone: "
  453. f"{list(_get_pruned_subnets(subnets))}."
  454. )
  455. # Use subnets in only one VPC, so that _configure_security_groups only
  456. # needs to create a security group in this one VPC. Otherwise, we'd need
  457. # to set up security groups in all of the user's VPCs and set up networking
  458. # rules to allow traffic between these groups.
  459. # See https://github.com/ray-project/ray/pull/14868.
  460. first_subnet_vpc_id = subnets[0].vpc_id
  461. subnets = [s.subnet_id for s in subnets if s.vpc_id == subnets[0].vpc_id]
  462. if _are_user_subnets_pruned(subnets):
  463. subnet_vpcs = {s.subnet_id: s.vpc_id for s in user_specified_subnets}
  464. cli_logger.abort(
  465. f"Subnets specified in more than one VPC for node type `{node_type_key}`! "
  466. f"Please ensure that all subnets share the same VPC and retry your "
  467. "request. Subnet VPCs: {}",
  468. subnet_vpcs,
  469. )
  470. return subnets, first_subnet_vpc_id
  471. def _configure_subnet(config):
  472. ec2 = _resource("ec2", config)
  473. # If head or worker security group is specified, filter down to subnets
  474. # belonging to the same VPC as the security group.
  475. sg_ids = []
  476. for node_type in config["available_node_types"].values():
  477. node_config = node_type["node_config"]
  478. sg_ids.extend(node_config.get("SecurityGroupIds", []))
  479. if sg_ids:
  480. vpc_id_of_sg = _get_vpc_id_of_sg(sg_ids, config)
  481. else:
  482. vpc_id_of_sg = None
  483. # map from node type key -> source of SubnetIds field
  484. subnet_src_info = {}
  485. _set_config_info(subnet_src=subnet_src_info)
  486. all_subnets = list(ec2.subnets.all())
  487. # separate node types with and without user-specified subnets
  488. node_types_subnets = []
  489. node_types_no_subnets = []
  490. for key, node_type in config["available_node_types"].items():
  491. if "SubnetIds" in node_type["node_config"]:
  492. node_types_subnets.append((key, node_type))
  493. else:
  494. node_types_no_subnets.append((key, node_type))
  495. vpc_id = None
  496. # iterate over node types with user-specified subnets first...
  497. for key, node_type in node_types_subnets:
  498. node_config = node_type["node_config"]
  499. user_subnets = _get_subnets_or_die(ec2, tuple(node_config["SubnetIds"]))
  500. subnet_ids, vpc_id = _usable_subnet_ids(
  501. user_subnets,
  502. all_subnets,
  503. azs=config["provider"].get("availability_zone"),
  504. vpc_id_of_sg=vpc_id_of_sg,
  505. use_internal_ips=config["provider"].get("use_internal_ips", False),
  506. node_type_key=key,
  507. )
  508. subnet_src_info[key] = "config"
  509. # lock-in a good VPC shared by the last set of user-specified subnets...
  510. if vpc_id and not vpc_id_of_sg:
  511. vpc_id_of_sg = vpc_id
  512. # iterate over node types without user-specified subnets last...
  513. for key, node_type in node_types_no_subnets:
  514. node_config = node_type["node_config"]
  515. subnet_ids, vpc_id = _usable_subnet_ids(
  516. None,
  517. all_subnets,
  518. azs=config["provider"].get("availability_zone"),
  519. vpc_id_of_sg=vpc_id_of_sg,
  520. use_internal_ips=config["provider"].get("use_internal_ips", False),
  521. node_type_key=key,
  522. )
  523. subnet_src_info[key] = "default"
  524. node_config["SubnetIds"] = subnet_ids
  525. return config
  526. def _get_vpc_id_of_sg(sg_ids: List[str], config: Dict[str, Any]) -> str:
  527. """Returns the VPC id of the security groups with the provided security
  528. group ids.
  529. Errors if the provided security groups belong to multiple VPCs.
  530. Errors if no security group with any of the provided ids is identified.
  531. """
  532. # sort security group IDs to support deterministic unit test stubbing
  533. sg_ids = sorted(set(sg_ids))
  534. ec2 = _resource("ec2", config)
  535. filters = [{"Name": "group-id", "Values": sg_ids}]
  536. security_groups = ec2.security_groups.filter(Filters=filters)
  537. vpc_ids = [sg.vpc_id for sg in security_groups]
  538. vpc_ids = list(set(vpc_ids))
  539. multiple_vpc_msg = (
  540. "All security groups specified in the cluster config "
  541. "should belong to the same VPC."
  542. )
  543. cli_logger.doassert(len(vpc_ids) <= 1, multiple_vpc_msg)
  544. assert len(vpc_ids) <= 1, multiple_vpc_msg
  545. no_sg_msg = (
  546. "Failed to detect a security group with id equal to any of "
  547. "the configured SecurityGroupIds."
  548. )
  549. cli_logger.doassert(len(vpc_ids) > 0, no_sg_msg)
  550. assert len(vpc_ids) > 0, no_sg_msg
  551. return vpc_ids[0]
  552. def _configure_security_group(config):
  553. # map from node type key -> source of SecurityGroupIds field
  554. security_group_info_src = {}
  555. _set_config_info(security_group_src=security_group_info_src)
  556. for node_type_key in config["available_node_types"]:
  557. security_group_info_src[node_type_key] = "config"
  558. node_types_to_configure = [
  559. node_type_key
  560. for node_type_key, node_type in config["available_node_types"].items()
  561. if "SecurityGroupIds" not in node_type["node_config"]
  562. ]
  563. if not node_types_to_configure:
  564. return config # have user-defined groups
  565. head_node_type = config["head_node_type"]
  566. if config["head_node_type"] in node_types_to_configure:
  567. # configure head node security group last for determinism
  568. # in tests
  569. node_types_to_configure.remove(head_node_type)
  570. node_types_to_configure.append(head_node_type)
  571. security_groups = _upsert_security_groups(config, node_types_to_configure)
  572. for node_type_key in node_types_to_configure:
  573. node_config = config["available_node_types"][node_type_key]["node_config"]
  574. sg = security_groups[node_type_key]
  575. node_config["SecurityGroupIds"] = [sg.id]
  576. security_group_info_src[node_type_key] = "default"
  577. return config
  578. def _check_ami(config):
  579. """Provide helpful message for missing ImageId for node configuration."""
  580. # map from node type key -> source of ImageId field
  581. ami_src_info = {key: "config" for key in config["available_node_types"]}
  582. _set_config_info(ami_src=ami_src_info)
  583. region = config["provider"]["region"]
  584. default_ami = DEFAULT_AMI.get(region)
  585. for key, node_type in config["available_node_types"].items():
  586. node_config = node_type["node_config"]
  587. node_ami = node_config.get("ImageId", "").lower()
  588. if node_ami in ["", "latest_dlami"]:
  589. if not default_ami:
  590. cli_logger.abort(
  591. f"Node type `{key}` has no ImageId in its node_config "
  592. f"and no default AMI is available for the region `{region}`. "
  593. "ImageId will need to be set manually in your cluster config."
  594. )
  595. else:
  596. node_config["ImageId"] = default_ami
  597. ami_src_info[key] = "dlami"
  598. def _upsert_security_groups(config, node_types):
  599. security_groups = _get_or_create_vpc_security_groups(config, node_types)
  600. _upsert_security_group_rules(config, security_groups)
  601. return security_groups
  602. def _get_or_create_vpc_security_groups(conf, node_types):
  603. # Figure out which VPC each node_type is in...
  604. ec2 = _resource("ec2", conf)
  605. node_type_to_vpc = {
  606. node_type: _get_vpc_id_or_die(
  607. ec2,
  608. conf["available_node_types"][node_type]["node_config"]["SubnetIds"][0],
  609. )
  610. for node_type in node_types
  611. }
  612. # Generate the name of the security group we're looking for...
  613. expected_sg_name = (
  614. conf["provider"]
  615. .get("security_group", {})
  616. .get("GroupName", SECURITY_GROUP_TEMPLATE.format(conf["cluster_name"]))
  617. )
  618. # Figure out which security groups with this name exist for each VPC...
  619. vpc_to_existing_sg = {
  620. sg.vpc_id: sg
  621. for sg in _get_security_groups(
  622. conf,
  623. node_type_to_vpc.values(),
  624. [expected_sg_name],
  625. )
  626. }
  627. # Lazily create any security group we're missing for each VPC...
  628. vpc_to_sg = LazyDefaultDict(
  629. partial(_create_security_group, conf, group_name=expected_sg_name),
  630. vpc_to_existing_sg,
  631. )
  632. # Then return a mapping from each node_type to its security group...
  633. return {
  634. node_type: vpc_to_sg[vpc_id] for node_type, vpc_id in node_type_to_vpc.items()
  635. }
  636. def _get_vpc_id_or_die(ec2, subnet_id: str):
  637. subnets = _get_subnets_or_die(ec2, (subnet_id,))
  638. cli_logger.doassert(
  639. len(subnets) == 1,
  640. f"Expected 1 subnet with ID `{subnet_id}` but found {len(subnets)}",
  641. )
  642. return subnets[0].vpc_id
  643. @lru_cache()
  644. def _get_subnets_or_die(ec2, subnet_ids: Tuple[str]):
  645. # Remove any duplicates as multiple interfaces are allowed to use same subnet
  646. subnet_ids = tuple(Counter(subnet_ids).keys())
  647. subnets = list(
  648. ec2.subnets.filter(Filters=[{"Name": "subnet-id", "Values": list(subnet_ids)}])
  649. )
  650. # TODO: better error message
  651. cli_logger.doassert(
  652. len(subnets) == len(subnet_ids), "Not all subnet IDs found: {}", subnet_ids
  653. )
  654. assert len(subnets) == len(subnet_ids), "Subnet ID not found: {}".format(subnet_ids)
  655. return subnets
  656. def _get_security_group(config, vpc_id, group_name):
  657. security_group = _get_security_groups(config, [vpc_id], [group_name])
  658. return None if not security_group else security_group[0]
  659. def _get_security_groups(config, vpc_ids, group_names):
  660. unique_vpc_ids = list(set(vpc_ids))
  661. unique_group_names = set(group_names)
  662. ec2 = _resource("ec2", config)
  663. existing_groups = list(
  664. ec2.security_groups.filter(
  665. Filters=[{"Name": "vpc-id", "Values": unique_vpc_ids}]
  666. )
  667. )
  668. filtered_groups = [
  669. sg for sg in existing_groups if sg.group_name in unique_group_names
  670. ]
  671. return filtered_groups
  672. def _create_security_group(config, vpc_id, group_name):
  673. client = _client("ec2", config)
  674. client.create_security_group(
  675. Description="Auto-created security group for Ray workers",
  676. GroupName=group_name,
  677. VpcId=vpc_id,
  678. TagSpecifications=[
  679. {
  680. "ResourceType": "security-group",
  681. "Tags": [
  682. {"Key": RAY, "Value": "true"},
  683. {"Key": "ray-cluster-name", "Value": config["cluster_name"]},
  684. ],
  685. },
  686. ],
  687. )
  688. security_group = _get_security_group(config, vpc_id, group_name)
  689. cli_logger.doassert(security_group, "Failed to create security group") # err msg
  690. cli_logger.verbose(
  691. "Created new security group {}",
  692. cf.bold(security_group.group_name),
  693. _tags=dict(id=security_group.id),
  694. )
  695. cli_logger.doassert(security_group, "Failed to create security group") # err msg
  696. assert security_group, "Failed to create security group"
  697. return security_group
  698. def _upsert_security_group_rules(conf, security_groups):
  699. sgids = {sg.id for sg in security_groups.values()}
  700. # Update sgids to include user-specified security groups.
  701. # This is necessary if the user specifies the head node type's security
  702. # groups but not the worker's, or vice-versa.
  703. for node_type in conf["available_node_types"]:
  704. sgids.update(
  705. conf["available_node_types"][node_type].get("SecurityGroupIds", [])
  706. )
  707. # sort security group items for deterministic inbound rule config order
  708. # (mainly supports more precise stub-based boto3 unit testing)
  709. for node_type, sg in sorted(security_groups.items()):
  710. sg = security_groups[node_type]
  711. if not sg.ip_permissions:
  712. _update_inbound_rules(sg, sgids, conf)
  713. def _update_inbound_rules(target_security_group, sgids, config):
  714. extended_rules = (
  715. config["provider"].get("security_group", {}).get("IpPermissions", [])
  716. )
  717. ip_permissions = _create_default_inbound_rules(sgids, extended_rules)
  718. target_security_group.authorize_ingress(IpPermissions=ip_permissions)
  719. def _create_default_inbound_rules(sgids, extended_rules=None):
  720. if extended_rules is None:
  721. extended_rules = []
  722. intracluster_rules = _create_default_intracluster_inbound_rules(sgids)
  723. ssh_rules = _create_default_ssh_inbound_rules()
  724. merged_rules = itertools.chain(
  725. intracluster_rules,
  726. ssh_rules,
  727. extended_rules,
  728. )
  729. return list(merged_rules)
  730. def _create_default_intracluster_inbound_rules(intracluster_sgids):
  731. return [
  732. {
  733. "FromPort": -1,
  734. "ToPort": -1,
  735. "IpProtocol": "-1",
  736. "UserIdGroupPairs": [
  737. {"GroupId": security_group_id}
  738. for security_group_id in sorted(intracluster_sgids)
  739. # sort security group IDs for deterministic IpPermission models
  740. # (mainly supports more precise stub-based boto3 unit testing)
  741. ],
  742. }
  743. ]
  744. def _create_default_ssh_inbound_rules():
  745. return [
  746. {
  747. "FromPort": 22,
  748. "ToPort": 22,
  749. "IpProtocol": "tcp",
  750. "IpRanges": [{"CidrIp": "0.0.0.0/0"}],
  751. }
  752. ]
  753. def _get_role(role_name, config):
  754. iam = _resource("iam", config)
  755. role = iam.Role(role_name)
  756. try:
  757. role.load()
  758. return role
  759. except botocore.exceptions.ClientError as exc:
  760. if exc.response.get("Error", {}).get("Code") == "NoSuchEntity":
  761. return None
  762. else:
  763. handle_boto_error(
  764. exc,
  765. "Failed to fetch IAM role data for {} from AWS.",
  766. cf.bold(role_name),
  767. )
  768. raise exc
  769. def _get_instance_profile(profile_name, config):
  770. iam = _resource("iam", config)
  771. profile = iam.InstanceProfile(profile_name)
  772. try:
  773. profile.load()
  774. return profile
  775. except botocore.exceptions.ClientError as exc:
  776. if exc.response.get("Error", {}).get("Code") == "NoSuchEntity":
  777. return None
  778. else:
  779. handle_boto_error(
  780. exc,
  781. "Failed to fetch IAM instance profile data for {} from AWS.",
  782. cf.bold(profile_name),
  783. )
  784. raise exc
  785. def _get_key(key_name, config):
  786. ec2 = _resource("ec2", config)
  787. try:
  788. for key in ec2.key_pairs.filter(
  789. Filters=[{"Name": "key-name", "Values": [key_name]}]
  790. ):
  791. if key.name == key_name:
  792. return key
  793. except botocore.exceptions.ClientError as exc:
  794. handle_boto_error(
  795. exc, "Failed to fetch EC2 key pair {} from AWS.", cf.bold(key_name)
  796. )
  797. raise exc
  798. def _configure_from_launch_template(config: Dict[str, Any]) -> Dict[str, Any]:
  799. """Merges any launch template data referenced by the node config of all
  800. available node type's into their parent node config. Any parameters
  801. specified in node config override the same parameters in the launch
  802. template, in compliance with the behavior of the ec2.create_instances
  803. API.
  804. Args:
  805. config (Dict[str, Any]): config to bootstrap
  806. Returns:
  807. config (Dict[str, Any]): The input config with all launch template
  808. data merged into the node config of all available node types. If no
  809. launch template data is found, then the config is returned
  810. unchanged.
  811. Raises:
  812. ValueError: When no launch template is found for the given launch template
  813. [name|id] and version, or more than one launch template is found.
  814. """
  815. # create a copy of the input config to modify
  816. config = copy.deepcopy(config)
  817. node_types = config["available_node_types"]
  818. # iterate over sorted node types to support deterministic unit test stubs
  819. for name, node_type in sorted(node_types.items()):
  820. node_types[name] = _configure_node_type_from_launch_template(config, node_type)
  821. return config
  822. def _configure_node_type_from_launch_template(
  823. config: Dict[str, Any], node_type: Dict[str, Any]
  824. ) -> Dict[str, Any]:
  825. """Merges any launch template data referenced by the given node type's
  826. node config into the parent node config. Any parameters specified in
  827. node config override the same parameters in the launch template.
  828. Args:
  829. config (Dict[str, Any]): config to bootstrap
  830. node_type (Dict[str, Any]): node type config to bootstrap
  831. Returns:
  832. node_type (Dict[str, Any]): The input config with all launch template
  833. data merged into the node config of the input node type. If no
  834. launch template data is found, then the config is returned
  835. unchanged.
  836. Raises:
  837. ValueError: When no launch template is found for the given launch template
  838. [name|id] and version, or more than one launch template is found.
  839. """
  840. # create a copy of the input config to modify
  841. node_type = copy.deepcopy(node_type)
  842. node_cfg = node_type["node_config"]
  843. if "LaunchTemplate" in node_cfg:
  844. node_type["node_config"] = _configure_node_cfg_from_launch_template(
  845. config, node_cfg
  846. )
  847. return node_type
  848. def _configure_node_cfg_from_launch_template(
  849. config: Dict[str, Any], node_cfg: Dict[str, Any]
  850. ) -> Dict[str, Any]:
  851. """Merges any launch template data referenced by the given node type's
  852. node config into the parent node config. Any parameters specified in
  853. node config override the same parameters in the launch template.
  854. Note that this merge is simply a bidirectional dictionary update, from
  855. the node config to the launch template data, and from the launch
  856. template data to the node config. Thus, the final result captures the
  857. relative complement of launch template data with respect to node config,
  858. and allows all subsequent config bootstrapping code paths to act as
  859. if the complement was explicitly specified in the user's node config. A
  860. deep merge of nested elements like tag specifications isn't required
  861. here, since the AWSNodeProvider's ec2.create_instances call will do this
  862. for us after it fetches the referenced launch template data.
  863. Args:
  864. config (Dict[str, Any]): config to bootstrap
  865. node_cfg (Dict[str, Any]): node config to bootstrap
  866. Returns:
  867. node_cfg (Dict[str, Any]): The input node config merged with all launch
  868. template data. If no launch template data is found, then the node
  869. config is returned unchanged.
  870. Raises:
  871. ValueError: When no launch template is found for the given launch template
  872. [name|id] and version, or more than one launch template is found.
  873. """
  874. # create a copy of the input config to modify
  875. node_cfg = copy.deepcopy(node_cfg)
  876. ec2 = _client("ec2", config)
  877. kwargs = copy.deepcopy(node_cfg["LaunchTemplate"])
  878. template_version = str(kwargs.pop("Version", "$Default"))
  879. # save the launch template version as a string to prevent errors from
  880. # passing an integer to ec2.create_instances in AWSNodeProvider
  881. node_cfg["LaunchTemplate"]["Version"] = template_version
  882. kwargs["Versions"] = [template_version] if template_version else []
  883. template = ec2.describe_launch_template_versions(**kwargs)
  884. lt_versions = template["LaunchTemplateVersions"]
  885. if len(lt_versions) != 1:
  886. raise ValueError(
  887. f"Expected to find 1 launch template but found " f"{len(lt_versions)}"
  888. )
  889. lt_data = template["LaunchTemplateVersions"][0]["LaunchTemplateData"]
  890. # override launch template parameters with explicit node config parameters
  891. lt_data.update(node_cfg)
  892. # copy all new launch template parameters back to node config
  893. node_cfg.update(lt_data)
  894. return node_cfg
  895. def _configure_from_network_interfaces(config: Dict[str, Any]) -> Dict[str, Any]:
  896. """Copies all network interface subnet and security group IDs up to their
  897. parent node config for each available node type.
  898. Args:
  899. config (Dict[str, Any]): config to bootstrap
  900. Returns:
  901. config (Dict[str, Any]): The input config with all network interface
  902. subnet and security group IDs copied into the node config of all
  903. available node types. If no network interfaces are found, then the
  904. config is returned unchanged.
  905. Raises:
  906. ValueError: If [1] subnet and security group IDs exist at both the
  907. node config and network interface levels, [2] any network interface
  908. doesn't have a subnet defined, or [3] any network interface doesn't
  909. have a security group defined.
  910. """
  911. # create a copy of the input config to modify
  912. config = copy.deepcopy(config)
  913. node_types = config["available_node_types"]
  914. for name, node_type in node_types.items():
  915. node_types[name] = _configure_node_type_from_network_interface(node_type)
  916. return config
  917. def _configure_node_type_from_network_interface(
  918. node_type: Dict[str, Any]
  919. ) -> Dict[str, Any]:
  920. """Copies all network interface subnet and security group IDs up to the
  921. parent node config for the given node type.
  922. Args:
  923. node_type (Dict[str, Any]): node type config to bootstrap
  924. Returns:
  925. node_type (Dict[str, Any]): The input config with all network interface
  926. subnet and security group IDs copied into the node config of the
  927. given node type. If no network interfaces are found, then the
  928. config is returned unchanged.
  929. Raises:
  930. ValueError: If [1] subnet and security group IDs exist at both the
  931. node config and network interface levels, [2] any network interface
  932. doesn't have a subnet defined, or [3] any network interface doesn't
  933. have a security group defined.
  934. """
  935. # create a copy of the input config to modify
  936. node_type = copy.deepcopy(node_type)
  937. node_cfg = node_type["node_config"]
  938. if "NetworkInterfaces" in node_cfg:
  939. node_type[
  940. "node_config"
  941. ] = _configure_subnets_and_groups_from_network_interfaces(node_cfg)
  942. return node_type
  943. def _configure_subnets_and_groups_from_network_interfaces(
  944. node_cfg: Dict[str, Any]
  945. ) -> Dict[str, Any]:
  946. """Copies all network interface subnet and security group IDs into their
  947. parent node config.
  948. Args:
  949. node_cfg (Dict[str, Any]): node config to bootstrap
  950. Returns:
  951. node_cfg (Dict[str, Any]): node config with all copied network
  952. interface subnet and security group IDs
  953. Raises:
  954. ValueError: If [1] subnet and security group IDs exist at both the
  955. node config and network interface levels, [2] any network interface
  956. doesn't have a subnet defined, or [3] any network interface doesn't
  957. have a security group defined.
  958. """
  959. # create a copy of the input config to modify
  960. node_cfg = copy.deepcopy(node_cfg)
  961. # If NetworkInterfaces are defined, SubnetId and SecurityGroupIds
  962. # can't be specified in the same node type config.
  963. conflict_keys = ["SubnetId", "SubnetIds", "SecurityGroupIds"]
  964. if any(conflict in node_cfg for conflict in conflict_keys):
  965. raise ValueError(
  966. "If NetworkInterfaces are defined, subnets and security groups "
  967. "must ONLY be given in each NetworkInterface."
  968. )
  969. subnets = _subnets_in_network_config(node_cfg)
  970. if not all(subnets):
  971. raise ValueError(
  972. "NetworkInterfaces are defined but at least one is missing a "
  973. "subnet. Please ensure all interfaces have a subnet assigned."
  974. )
  975. security_groups = _security_groups_in_network_config(node_cfg)
  976. if not all(security_groups):
  977. raise ValueError(
  978. "NetworkInterfaces are defined but at least one is missing a "
  979. "security group. Please ensure all interfaces have a security "
  980. "group assigned."
  981. )
  982. node_cfg["SubnetIds"] = subnets
  983. node_cfg["SecurityGroupIds"] = list(itertools.chain(*security_groups))
  984. return node_cfg
  985. def _subnets_in_network_config(config: Dict[str, Any]) -> List[str]:
  986. """
  987. Returns all subnet IDs found in the given node config's network interfaces.
  988. Args:
  989. config (Dict[str, Any]): node config
  990. Returns:
  991. subnet_ids (List[str]): List of subnet IDs for all network interfaces,
  992. or an empty list if no network interfaces are defined. An empty string
  993. is returned for each missing network interface subnet ID.
  994. """
  995. return [ni.get("SubnetId", "") for ni in config.get("NetworkInterfaces", [])]
  996. def _security_groups_in_network_config(config: Dict[str, Any]) -> List[List[str]]:
  997. """
  998. Returns all security group IDs found in the given node config's network
  999. interfaces.
  1000. Args:
  1001. config (Dict[str, Any]): node config
  1002. Returns:
  1003. security_group_ids (List[List[str]]): List of security group ID lists
  1004. for all network interfaces, or an empty list if no network interfaces
  1005. are defined. An empty list is returned for each missing network
  1006. interface security group list.
  1007. """
  1008. return [ni.get("Groups", []) for ni in config.get("NetworkInterfaces", [])]
  1009. def _client(name, config):
  1010. return _resource(name, config).meta.client
  1011. def _resource(name, config):
  1012. region = config["provider"]["region"]
  1013. aws_credentials = config["provider"].get("aws_credentials", {})
  1014. return resource_cache(name, region, **aws_credentials)