aws_environment.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. """Implements the AWS environment."""
  2. from __future__ import annotations
  3. import logging
  4. import os
  5. from wandb.sdk.launch.errors import LaunchError
  6. from wandb.util import get_module
  7. from ..utils import ARN_PARTITION_RE, S3_URI_RE, event_loop_thread_exec
  8. from .abstract import AbstractEnvironment
  9. boto3 = get_module(
  10. "boto3",
  11. required="AWS environment requires boto3 to be installed. Please install "
  12. "it with `pip install wandb[launch]`.",
  13. )
  14. botocore = get_module(
  15. "botocore",
  16. required="AWS environment requires botocore to be installed. Please install "
  17. "it with `pip install wandb[launch]`.",
  18. )
  19. _logger = logging.getLogger(__name__)
  20. class AwsEnvironment(AbstractEnvironment):
  21. """AWS environment."""
  22. def __init__(
  23. self,
  24. region: str,
  25. access_key: str,
  26. secret_key: str,
  27. session_token: str,
  28. ) -> None:
  29. """Initialize the AWS environment.
  30. Arguments:
  31. region (str): The AWS region.
  32. Raises:
  33. LaunchError: If the AWS environment is not configured correctly.
  34. """
  35. super().__init__()
  36. _logger.info(f"Initializing AWS environment in region {region}.")
  37. self._region = region
  38. self._access_key = access_key
  39. self._secret_key = secret_key
  40. self._session_token = session_token
  41. self._account = None
  42. self._partition = None
  43. @classmethod
  44. def from_default(cls, region: str | None = None) -> AwsEnvironment:
  45. """Create an AWS environment from the default AWS environment.
  46. Arguments:
  47. region (str, optional): The AWS region.
  48. verify (bool, optional): Whether to verify the AWS environment. Defaults to True.
  49. Returns:
  50. AwsEnvironment: The AWS environment.
  51. """
  52. _logger.info("Creating AWS environment from default credentials.")
  53. try:
  54. session = boto3.Session()
  55. if hasattr(session, "region"):
  56. region = region or session.region
  57. region = region or os.environ.get("AWS_REGION")
  58. credentials = session.get_credentials()
  59. if not credentials:
  60. raise LaunchError(
  61. "Could not create AWS environment from default environment. Please verify that your AWS credentials are configured correctly."
  62. )
  63. access_key = credentials.access_key
  64. secret_key = credentials.secret_key
  65. session_token = credentials.token
  66. except botocore.client.ClientError as e:
  67. raise LaunchError(
  68. f"Could not create AWS environment from default environment. Please verify that your AWS credentials are configured correctly. {e}"
  69. )
  70. if not region:
  71. raise LaunchError(
  72. "Could not create AWS environment from default environment. Region not specified."
  73. )
  74. return cls(
  75. region=region,
  76. access_key=access_key,
  77. secret_key=secret_key,
  78. session_token=session_token,
  79. )
  80. @classmethod
  81. def from_config(
  82. cls,
  83. config: dict[str, str],
  84. ) -> AwsEnvironment:
  85. """Create an AWS environment from the default AWS environment.
  86. Arguments:
  87. config (dict): Configuration dictionary.
  88. verify (bool, optional): Whether to verify the AWS environment. Defaults to True.
  89. Returns:
  90. AwsEnvironment: The AWS environment.
  91. """
  92. region = str(config.get("region", ""))
  93. if not region:
  94. raise LaunchError(
  95. "Could not create AWS environment from config. Region not specified."
  96. )
  97. return cls.from_default(
  98. region=region,
  99. )
  100. @property
  101. def region(self) -> str:
  102. """The AWS region."""
  103. return self._region
  104. @region.setter
  105. def region(self, region: str) -> None:
  106. self._region = region
  107. async def get_partition(self) -> str:
  108. """Set the partition for the AWS environment."""
  109. try:
  110. session = await self.get_session()
  111. client = await event_loop_thread_exec(session.client)("sts")
  112. get_caller_identity = event_loop_thread_exec(client.get_caller_identity)
  113. identity = await get_caller_identity()
  114. arn = identity.get("Arn")
  115. if not arn:
  116. raise LaunchError(
  117. "Could not set partition for AWS environment. ARN not found."
  118. )
  119. matched_partition = ARN_PARTITION_RE.match(arn)
  120. if not matched_partition:
  121. raise LaunchError(
  122. f"Could not set partition for AWS environment. ARN {arn} is not valid."
  123. )
  124. partition = matched_partition.group(1)
  125. return partition
  126. except botocore.exceptions.ClientError as e:
  127. raise LaunchError(
  128. f"Could not set partition for AWS environment. {e}"
  129. ) from e
  130. async def verify(self) -> None:
  131. """Verify that the AWS environment is configured correctly.
  132. Raises:
  133. LaunchError: If the AWS environment is not configured correctly.
  134. """
  135. _logger.debug("Verifying AWS environment.")
  136. try:
  137. session = await self.get_session()
  138. client = await event_loop_thread_exec(session.client)("sts")
  139. get_caller_identity = event_loop_thread_exec(client.get_caller_identity)
  140. self._account = (await get_caller_identity()).get("Account")
  141. # TODO: log identity details from the response
  142. except botocore.exceptions.ClientError as e:
  143. raise LaunchError(
  144. f"Could not verify AWS environment. Please verify that your AWS credentials are configured correctly. {e}"
  145. ) from e
  146. async def get_session(self) -> boto3.Session: # type: ignore
  147. """Get an AWS session.
  148. Returns:
  149. boto3.Session: The AWS session.
  150. Raises:
  151. LaunchError: If the AWS session could not be created.
  152. """
  153. _logger.debug(f"Creating AWS session in region {self._region}")
  154. try:
  155. session = event_loop_thread_exec(boto3.Session)
  156. return await session(
  157. region_name=self._region,
  158. aws_access_key_id=self._access_key,
  159. aws_secret_access_key=self._secret_key,
  160. aws_session_token=self._session_token,
  161. )
  162. except botocore.exceptions.ClientError as e:
  163. raise LaunchError(f"Could not create AWS session. {e}")
  164. async def upload_file(self, source: str, destination: str) -> None:
  165. """Upload a file to s3 from local storage.
  166. The destination is a valid s3 URI, e.g. s3://bucket/key and will
  167. be used as a prefix for the uploaded file. Only the filename of the source
  168. is kept in the upload key. So if the source is "foo/bar" and the
  169. destination is "s3://bucket/key", the file "foo/bar" will be uploaded
  170. to "s3://bucket/key/bar".
  171. Arguments:
  172. source (str): The path to the file or directory.
  173. destination (str): The uri of the storage destination. This should
  174. be a valid s3 URI, e.g. s3://bucket/key.
  175. Raises:
  176. LaunchError: If the copy fails, the source path does not exist, or the
  177. destination is not a valid s3 URI, or the upload fails.
  178. """
  179. _logger.debug(f"Uploading {source} to {destination}")
  180. _err_prefix = f"Error attempting to copy {source} to {destination}."
  181. if not os.path.isfile(source):
  182. raise LaunchError(f"{_err_prefix}: Source {source} does not exist.")
  183. match = S3_URI_RE.match(destination)
  184. if not match:
  185. raise LaunchError(
  186. f"{_err_prefix}: Destination {destination} is not a valid s3 URI."
  187. )
  188. bucket = match.group(1)
  189. key = match.group(2).lstrip("/")
  190. if not key:
  191. key = ""
  192. session = await self.get_session()
  193. try:
  194. client = await event_loop_thread_exec(session.client)("s3")
  195. client.upload_file(source, bucket, key)
  196. except botocore.exceptions.ClientError as e:
  197. raise LaunchError(
  198. f"{_err_prefix}: botocore error attempting to copy {source} to {destination}. {e}"
  199. )
  200. async def upload_dir(self, source: str, destination: str) -> None:
  201. """Upload a directory to s3 from local storage.
  202. The upload will place the contents of the source directory in the destination
  203. with the same directory structure. So if the source is "foo/bar" and the
  204. destination is "s3://bucket/key", the contents of "foo/bar" will be uploaded
  205. to "s3://bucket/key/bar".
  206. Arguments:
  207. source (str): The path to the file or directory.
  208. destination (str): The URI of the storage.
  209. recursive (bool, optional): If True, copy the directory recursively. Defaults to False.
  210. Raises:
  211. LaunchError: If the copy fails, the source path does not exist, or the
  212. destination is not a valid s3 URI.
  213. """
  214. _logger.debug(f"Uploading {source} to {destination}")
  215. _err_prefix = f"Error attempting to copy {source} to {destination}."
  216. if not os.path.isdir(source):
  217. raise LaunchError(f"{_err_prefix}: Source {source} does not exist.")
  218. match = S3_URI_RE.match(destination)
  219. if not match:
  220. raise LaunchError(
  221. f"{_err_prefix}: Destination {destination} is not a valid s3 URI."
  222. )
  223. bucket = match.group(1)
  224. key = match.group(2).lstrip("/")
  225. if not key:
  226. key = ""
  227. session = await self.get_session()
  228. try:
  229. client = await event_loop_thread_exec(session.client)("s3")
  230. for path, _, files in os.walk(source):
  231. for file in files:
  232. abs_path = os.path.join(path, file)
  233. key_path = (
  234. abs_path.replace(source, "").replace("\\", "/").lstrip("/")
  235. )
  236. client.upload_file(
  237. abs_path,
  238. bucket,
  239. key_path,
  240. )
  241. except botocore.exceptions.ClientError as e:
  242. raise LaunchError(
  243. f"{_err_prefix}: botocore error attempting to copy {source} to {destination}. {e}"
  244. ) from e
  245. except Exception as e:
  246. raise LaunchError(
  247. f"{_err_prefix}: Unexpected error attempting to copy {source} to {destination}. {e}"
  248. ) from e
  249. async def verify_storage_uri(self, uri: str) -> None:
  250. """Verify that s3 storage is configured correctly.
  251. This will check that the bucket exists and that the credentials are
  252. configured correctly.
  253. Arguments:
  254. uri (str): The URI of the storage.
  255. Raises:
  256. LaunchError: If the storage is not configured correctly or the URI is
  257. not a valid s3 URI.
  258. Returns:
  259. None
  260. """
  261. _logger.debug(f"Verifying storage {uri}")
  262. match = S3_URI_RE.match(uri)
  263. if not match:
  264. raise LaunchError(
  265. f"Failed to validate storage uri: {uri} is not a valid s3 URI."
  266. )
  267. bucket = match.group(1)
  268. try:
  269. session = await self.get_session()
  270. client = await event_loop_thread_exec(session.client)("s3")
  271. client.head_bucket(Bucket=bucket)
  272. except botocore.exceptions.ClientError as e:
  273. if e.response["Error"]["Code"] == "404":
  274. raise LaunchError(
  275. f"Could not verify AWS storage uri {uri}. Bucket {bucket} does not exist."
  276. )
  277. if e.response["Error"]["Code"] == "403":
  278. raise LaunchError(
  279. f"Could not verify AWS storage uri {uri}. "
  280. "Bucket {bucket} is not accessible. Please check that this "
  281. "client is authenticated with permission to access the bucket."
  282. )
  283. raise LaunchError(
  284. f"Failed to verify AWS storage uri {uri}. Response: {e.response} Please verify that your AWS credentials are configured correctly."
  285. )