# config.py (3.7 KB)

  1. import logging
  2. import os
  3. import stat
  4. from ray.autoscaler._private.aliyun.utils import AcsClient
  5. # instance status
  6. PENDING = "Pending"
  7. RUNNING = "Running"
  8. STARTING = "Starting"
  9. STOPPING = "Stopping"
  10. STOPPED = "Stopped"
  11. logger = logging.getLogger(__name__)
  12. def bootstrap_aliyun(config):
  13. # print(config["provider"])
  14. # create vpc
  15. _get_or_create_vpc(config)
  16. # create security group id
  17. _get_or_create_security_group(config)
  18. # create vswitch
  19. _get_or_create_vswitch(config)
  20. # create key pair
  21. _get_or_import_key_pair(config)
  22. # print(config["provider"])
  23. return config
  24. def _client(config):
  25. return AcsClient(
  26. access_key=config["provider"].get("access_key"),
  27. access_key_secret=config["provider"].get("access_key_secret"),
  28. region_id=config["provider"]["region"],
  29. max_retries=1,
  30. )
  31. def _get_or_create_security_group(config):
  32. cli = _client(config)
  33. security_groups = cli.describe_security_groups(vpc_id=config["provider"]["vpc_id"])
  34. if security_groups is not None and len(security_groups) > 0:
  35. config["provider"]["security_group_id"] = security_groups[0]["SecurityGroupId"]
  36. return config
  37. security_group_id = cli.create_security_group(vpc_id=config["provider"]["vpc_id"])
  38. for rule in config["provider"].get("security_group_rule", {}):
  39. cli.authorize_security_group(
  40. security_group_id=security_group_id,
  41. port_range=rule["port_range"],
  42. source_cidr_ip=rule["source_cidr_ip"],
  43. ip_protocol=rule["ip_protocol"],
  44. )
  45. config["provider"]["security_group_id"] = security_group_id
  46. return
  47. def _get_or_create_vpc(config):
  48. cli = _client(config)
  49. vpcs = cli.describe_vpcs()
  50. if vpcs is not None and len(vpcs) > 0:
  51. config["provider"]["vpc_id"] = vpcs[0].get("VpcId")
  52. return
  53. vpc_id = cli.create_vpc()
  54. if vpc_id is not None:
  55. config["provider"]["vpc_id"] = vpc_id
  56. def _get_or_create_vswitch(config):
  57. cli = _client(config)
  58. vswitches = cli.describe_v_switches(vpc_id=config["provider"]["vpc_id"])
  59. if vswitches is not None and len(vswitches) > 0:
  60. config["provider"]["v_switch_id"] = vswitches[0].get("VSwitchId")
  61. return
  62. v_switch_id = cli.create_v_switch(
  63. vpc_id=config["provider"]["vpc_id"],
  64. zone_id=config["provider"]["zone_id"],
  65. cidr_block=config["provider"]["cidr_block"],
  66. )
  67. if v_switch_id is not None:
  68. config["provider"]["v_switch_id"] = v_switch_id
  69. def _get_or_import_key_pair(config):
  70. cli = _client(config)
  71. key_name = config["provider"].get("key_name", "ray")
  72. key_path = os.path.expanduser("~/.ssh/{}".format(key_name))
  73. keypairs = cli.describe_key_pairs(key_pair_name=key_name)
  74. if keypairs is not None and len(keypairs) > 0:
  75. if "ssh_private_key" not in config["auth"]:
  76. logger.info(
  77. "{} keypair exists, use {} as local ssh key".format(key_name, key_path)
  78. )
  79. config["auth"]["ssh_private_key"] = key_path
  80. else:
  81. if "ssh_private_key" not in config["auth"]:
  82. # create new keypair
  83. resp = cli.create_key_pair(key_pair_name=key_name)
  84. if resp is not None:
  85. with open(key_path, "w+") as f:
  86. f.write(resp.get("PrivateKeyBody"))
  87. os.chmod(key_path, stat.S_IRUSR)
  88. config["auth"]["ssh_private_key"] = key_path
  89. else:
  90. public_key_file = config["auth"]["ssh_private_key"] + ".pub"
  91. # create new keypair, from local file
  92. with open(public_key_file) as f:
  93. public_key = f.readline().strip("\n")
  94. cli.import_key_pair(key_pair_name=key_name, public_key_body=public_key)
  95. return