"""Implements the AWS environment.""" from __future__ import annotations import logging import os from wandb.sdk.launch.errors import LaunchError from wandb.util import get_module from ..utils import ARN_PARTITION_RE, S3_URI_RE, event_loop_thread_exec from .abstract import AbstractEnvironment boto3 = get_module( "boto3", required="AWS environment requires boto3 to be installed. Please install " "it with `pip install wandb[launch]`.", ) botocore = get_module( "botocore", required="AWS environment requires botocore to be installed. Please install " "it with `pip install wandb[launch]`.", ) _logger = logging.getLogger(__name__) class AwsEnvironment(AbstractEnvironment): """AWS environment.""" def __init__( self, region: str, access_key: str, secret_key: str, session_token: str, ) -> None: """Initialize the AWS environment. Arguments: region (str): The AWS region. Raises: LaunchError: If the AWS environment is not configured correctly. """ super().__init__() _logger.info(f"Initializing AWS environment in region {region}.") self._region = region self._access_key = access_key self._secret_key = secret_key self._session_token = session_token self._account = None self._partition = None @classmethod def from_default(cls, region: str | None = None) -> AwsEnvironment: """Create an AWS environment from the default AWS environment. Arguments: region (str, optional): The AWS region. verify (bool, optional): Whether to verify the AWS environment. Defaults to True. Returns: AwsEnvironment: The AWS environment. """ _logger.info("Creating AWS environment from default credentials.") try: session = boto3.Session() if hasattr(session, "region"): region = region or session.region region = region or os.environ.get("AWS_REGION") credentials = session.get_credentials() if not credentials: raise LaunchError( "Could not create AWS environment from default environment. Please verify that your AWS credentials are configured correctly." ) access_key = credentials.access_key secret_key = credentials.secret_key session_token = credentials.token except botocore.client.ClientError as e: raise LaunchError( f"Could not create AWS environment from default environment. Please verify that your AWS credentials are configured correctly. {e}" ) if not region: raise LaunchError( "Could not create AWS environment from default environment. Region not specified." ) return cls( region=region, access_key=access_key, secret_key=secret_key, session_token=session_token, ) @classmethod def from_config( cls, config: dict[str, str], ) -> AwsEnvironment: """Create an AWS environment from the default AWS environment. Arguments: config (dict): Configuration dictionary. verify (bool, optional): Whether to verify the AWS environment. Defaults to True. Returns: AwsEnvironment: The AWS environment. """ region = str(config.get("region", "")) if not region: raise LaunchError( "Could not create AWS environment from config. Region not specified." ) return cls.from_default( region=region, ) @property def region(self) -> str: """The AWS region.""" return self._region @region.setter def region(self, region: str) -> None: self._region = region async def get_partition(self) -> str: """Set the partition for the AWS environment.""" try: session = await self.get_session() client = await event_loop_thread_exec(session.client)("sts") get_caller_identity = event_loop_thread_exec(client.get_caller_identity) identity = await get_caller_identity() arn = identity.get("Arn") if not arn: raise LaunchError( "Could not set partition for AWS environment. ARN not found." ) matched_partition = ARN_PARTITION_RE.match(arn) if not matched_partition: raise LaunchError( f"Could not set partition for AWS environment. ARN {arn} is not valid." ) partition = matched_partition.group(1) return partition except botocore.exceptions.ClientError as e: raise LaunchError( f"Could not set partition for AWS environment. {e}" ) from e async def verify(self) -> None: """Verify that the AWS environment is configured correctly. Raises: LaunchError: If the AWS environment is not configured correctly. """ _logger.debug("Verifying AWS environment.") try: session = await self.get_session() client = await event_loop_thread_exec(session.client)("sts") get_caller_identity = event_loop_thread_exec(client.get_caller_identity) self._account = (await get_caller_identity()).get("Account") # TODO: log identity details from the response except botocore.exceptions.ClientError as e: raise LaunchError( f"Could not verify AWS environment. Please verify that your AWS credentials are configured correctly. {e}" ) from e async def get_session(self) -> boto3.Session: # type: ignore """Get an AWS session. Returns: boto3.Session: The AWS session. Raises: LaunchError: If the AWS session could not be created. """ _logger.debug(f"Creating AWS session in region {self._region}") try: session = event_loop_thread_exec(boto3.Session) return await session( region_name=self._region, aws_access_key_id=self._access_key, aws_secret_access_key=self._secret_key, aws_session_token=self._session_token, ) except botocore.exceptions.ClientError as e: raise LaunchError(f"Could not create AWS session. {e}") async def upload_file(self, source: str, destination: str) -> None: """Upload a file to s3 from local storage. The destination is a valid s3 URI, e.g. s3://bucket/key and will be used as a prefix for the uploaded file. Only the filename of the source is kept in the upload key. So if the source is "foo/bar" and the destination is "s3://bucket/key", the file "foo/bar" will be uploaded to "s3://bucket/key/bar". Arguments: source (str): The path to the file or directory. destination (str): The uri of the storage destination. This should be a valid s3 URI, e.g. s3://bucket/key. Raises: LaunchError: If the copy fails, the source path does not exist, or the destination is not a valid s3 URI, or the upload fails. """ _logger.debug(f"Uploading {source} to {destination}") _err_prefix = f"Error attempting to copy {source} to {destination}." if not os.path.isfile(source): raise LaunchError(f"{_err_prefix}: Source {source} does not exist.") match = S3_URI_RE.match(destination) if not match: raise LaunchError( f"{_err_prefix}: Destination {destination} is not a valid s3 URI." ) bucket = match.group(1) key = match.group(2).lstrip("/") if not key: key = "" session = await self.get_session() try: client = await event_loop_thread_exec(session.client)("s3") client.upload_file(source, bucket, key) except botocore.exceptions.ClientError as e: raise LaunchError( f"{_err_prefix}: botocore error attempting to copy {source} to {destination}. {e}" ) async def upload_dir(self, source: str, destination: str) -> None: """Upload a directory to s3 from local storage. The upload will place the contents of the source directory in the destination with the same directory structure. So if the source is "foo/bar" and the destination is "s3://bucket/key", the contents of "foo/bar" will be uploaded to "s3://bucket/key/bar". Arguments: source (str): The path to the file or directory. destination (str): The URI of the storage. recursive (bool, optional): If True, copy the directory recursively. Defaults to False. Raises: LaunchError: If the copy fails, the source path does not exist, or the destination is not a valid s3 URI. """ _logger.debug(f"Uploading {source} to {destination}") _err_prefix = f"Error attempting to copy {source} to {destination}." if not os.path.isdir(source): raise LaunchError(f"{_err_prefix}: Source {source} does not exist.") match = S3_URI_RE.match(destination) if not match: raise LaunchError( f"{_err_prefix}: Destination {destination} is not a valid s3 URI." ) bucket = match.group(1) key = match.group(2).lstrip("/") if not key: key = "" session = await self.get_session() try: client = await event_loop_thread_exec(session.client)("s3") for path, _, files in os.walk(source): for file in files: abs_path = os.path.join(path, file) key_path = ( abs_path.replace(source, "").replace("\\", "/").lstrip("/") ) client.upload_file( abs_path, bucket, key_path, ) except botocore.exceptions.ClientError as e: raise LaunchError( f"{_err_prefix}: botocore error attempting to copy {source} to {destination}. {e}" ) from e except Exception as e: raise LaunchError( f"{_err_prefix}: Unexpected error attempting to copy {source} to {destination}. {e}" ) from e async def verify_storage_uri(self, uri: str) -> None: """Verify that s3 storage is configured correctly. This will check that the bucket exists and that the credentials are configured correctly. Arguments: uri (str): The URI of the storage. Raises: LaunchError: If the storage is not configured correctly or the URI is not a valid s3 URI. Returns: None """ _logger.debug(f"Verifying storage {uri}") match = S3_URI_RE.match(uri) if not match: raise LaunchError( f"Failed to validate storage uri: {uri} is not a valid s3 URI." ) bucket = match.group(1) try: session = await self.get_session() client = await event_loop_thread_exec(session.client)("s3") client.head_bucket(Bucket=bucket) except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] == "404": raise LaunchError( f"Could not verify AWS storage uri {uri}. Bucket {bucket} does not exist." ) if e.response["Error"]["Code"] == "403": raise LaunchError( f"Could not verify AWS storage uri {uri}. " "Bucket {bucket} is not accessible. Please check that this " "client is authenticated with permission to access the bucket." ) raise LaunchError( f"Failed to verify AWS storage uri {uri}. Response: {e.response} Please verify that your AWS credentials are configured correctly." )