| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116 |
- import logging
- import os
- import stat
- from ray.autoscaler._private.aliyun.utils import AcsClient
# Instance status strings. They are defined at module level so other code can
# compare against them by name instead of repeating the literals.
PENDING = "Pending"
RUNNING = "Running"
STARTING = "Starting"
STOPPING = "Stopping"
STOPPED = "Stopped"

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def bootstrap_aliyun(config):
    """Run every Aliyun bootstrap step against ``config`` and return it.

    The steps run in a fixed order: VPC, security group, vSwitch, key pair.
    The VPC step must come first because the security group and vSwitch
    steps both read ``config["provider"]["vpc_id"]``.

    Args:
        config: Cluster configuration dict. The helpers write their results
            into it in place.

    Returns:
        The same ``config`` object that was passed in.
    """
    setup_steps = (
        _get_or_create_vpc,
        _get_or_create_security_group,
        _get_or_create_vswitch,
        _get_or_import_key_pair,
    )
    for step in setup_steps:
        step(config)
    return config
def _client(config):
    """Build an ``AcsClient`` from the ``provider`` section of ``config``.

    ``access_key`` and ``access_key_secret`` are optional and are passed as
    ``None`` when absent; ``region`` is required. ``max_retries`` is always 1.

    Args:
        config: Cluster configuration dict containing a ``"provider"`` dict.

    Returns:
        A new ``AcsClient`` instance.

    Raises:
        KeyError: If ``"provider"`` or its ``"region"`` entry is missing.
    """
    provider = config["provider"]
    client_kwargs = {
        "access_key": provider.get("access_key"),
        "access_key_secret": provider.get("access_key_secret"),
        "region_id": provider["region"],
        "max_retries": 1,
    }
    return AcsClient(**client_kwargs)
def _get_or_create_security_group(config):
    """Store a security group id for the configured VPC in ``config``.

    If the VPC already has at least one security group, the first one
    returned is used. Otherwise a new group is created and every rule listed
    under ``config["provider"]["security_group_rule"]`` is authorized on it.

    Args:
        config: Cluster configuration dict. ``config["provider"]["vpc_id"]``
            must already be set.

    Returns:
        The same ``config`` object, with
        ``config["provider"]["security_group_id"]`` set. Both code paths
        return ``config``; the creation path previously returned ``None``.

    Raises:
        KeyError: If ``vpc_id`` is missing, or a rule lacks ``port_range``,
            ``source_cidr_ip`` or ``ip_protocol``.
    """
    cli = _client(config)
    provider = config["provider"]
    vpc_id = provider["vpc_id"]

    security_groups = cli.describe_security_groups(vpc_id=vpc_id)
    if security_groups:
        provider["security_group_id"] = security_groups[0]["SecurityGroupId"]
        return config

    security_group_id = cli.create_security_group(vpc_id=vpc_id)
    # `or ()` also covers a key that is present but empty (None), which
    # would otherwise fail when iterated.
    for rule in provider.get("security_group_rule") or ():
        cli.authorize_security_group(
            security_group_id=security_group_id,
            port_range=rule["port_range"],
            source_cidr_ip=rule["source_cidr_ip"],
            ip_protocol=rule["ip_protocol"],
        )
    provider["security_group_id"] = security_group_id
    return config
def _get_or_create_vpc(config):
    """Record a VPC id in ``config["provider"]["vpc_id"]``.

    The first VPC returned by ``describe_vpcs`` is used when there is one.
    Otherwise a VPC is created; if creation returns ``None`` the config is
    left unchanged.

    Args:
        config: Cluster configuration dict, updated in place.
    """
    cli = _client(config)
    existing = cli.describe_vpcs()
    if existing is None or len(existing) == 0:
        new_vpc_id = cli.create_vpc()
        if new_vpc_id is not None:
            config["provider"]["vpc_id"] = new_vpc_id
    else:
        config["provider"]["vpc_id"] = existing[0].get("VpcId")
def _get_or_create_vswitch(config):
    """Record a vSwitch id in ``config["provider"]["v_switch_id"]``.

    The first vSwitch found in the configured VPC is used when there is one.
    Otherwise a vSwitch is created from the provider's ``zone_id`` and
    ``cidr_block``; if creation returns ``None`` the config is left unchanged.

    Args:
        config: Cluster configuration dict, updated in place.
            ``config["provider"]["vpc_id"]`` must already be set.
    """
    cli = _client(config)
    found = cli.describe_v_switches(vpc_id=config["provider"]["vpc_id"])
    if found is None or len(found) == 0:
        provider = config["provider"]
        created_id = cli.create_v_switch(
            vpc_id=provider["vpc_id"],
            zone_id=provider["zone_id"],
            cidr_block=provider["cidr_block"],
        )
        if created_id is not None:
            config["provider"]["v_switch_id"] = created_id
    else:
        config["provider"]["v_switch_id"] = found[0].get("VSwitchId")
def _get_or_import_key_pair(config):
    """Make sure a key pair named ``key_name`` exists and is wired into ``config``.

    The key pair name comes from ``config["provider"]["key_name"]`` and
    defaults to ``"ray"``. The local private key path is ``~/.ssh/<key_name>``.

    Cases:
      * The key pair already exists remotely: if ``config["auth"]`` has no
        ``ssh_private_key``, point it at the local key path.
      * The key pair does not exist and ``ssh_private_key`` is configured:
        import the first line of ``<ssh_private_key>.pub`` as the public key.
      * The key pair does not exist and no ``ssh_private_key`` is configured:
        create a key pair, write the returned private key to the local key
        path readable by the owner only, and record that path in
        ``config["auth"]["ssh_private_key"]``. Nothing is written if the
        create call returns ``None``.

    Args:
        config: Cluster configuration dict with ``"provider"`` and ``"auth"``
            dicts; updated in place.

    Returns:
        None.
    """
    cli = _client(config)
    key_name = config["provider"].get("key_name", "ray")
    key_path = os.path.expanduser("~/.ssh/{}".format(key_name))
    keypairs = cli.describe_key_pairs(key_pair_name=key_name)

    auth = config["auth"]
    has_local_key = "ssh_private_key" in auth

    if keypairs:
        if not has_local_key:
            logger.info(
                "%s keypair exists, use %s as local ssh key", key_name, key_path
            )
            auth["ssh_private_key"] = key_path
        return

    if has_local_key:
        # Register the user's existing key: upload its public half.
        public_key_file = auth["ssh_private_key"] + ".pub"
        with open(public_key_file) as f:
            public_key = f.readline().strip("\n")
        cli.import_key_pair(key_pair_name=key_name, public_key_body=public_key)
        return

    # No remote key pair and no local key configured: create a new one.
    resp = cli.create_key_pair(key_pair_name=key_name)
    if resp is None:
        return

    # The ~/.ssh directory may not exist yet on a fresh machine.
    os.makedirs(os.path.dirname(key_path), mode=0o700, exist_ok=True)
    # Create the file owner-only from the start so the private key is never
    # on disk with default (umask-derived) permissions.
    fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as f:
        f.write(resp.get("PrivateKeyBody"))
    os.chmod(key_path, stat.S_IRUSR)
    auth["ssh_private_key"] = key_path
|