"""Implementation of the SageMakerRunner class."""

from __future__ import annotations

import asyncio
import logging
from typing import Any, cast

# Never executed at runtime: the import only exists so static type checkers
# can resolve the ``boto3.Client`` annotations below (annotations are lazy
# because of ``from __future__ import annotations``).
if False:
    import boto3  # type: ignore

import wandb
from wandb.apis.internal import Api
from wandb.sdk.launch.environment.aws_environment import AwsEnvironment
from wandb.sdk.launch.errors import LaunchError

from .._project_spec import EntryPoint, LaunchProject
from ..registry.abstract import AbstractRegistry
from ..utils import (
    LOG_PREFIX,
    MAX_ENV_LENGTHS,
    PROJECT_SYNCHRONOUS,
    event_loop_thread_exec,
    to_camel_case,
)
from .abstract import AbstractRun, AbstractRunner, Status

_logger = logging.getLogger(__name__)


class SagemakerSubmittedRun(AbstractRun):
    """Instance of ``AbstractRun`` tracking a training job submitted to AWS SageMaker.

    The run is identified by its SageMaker training job name; status, logs and
    cancellation are all implemented through the boto3 clients given at
    construction time.
    """

    def __init__(
        self,
        training_job_name: str,
        client: boto3.Client,
        log_client: boto3.Client | None = None,
    ) -> None:
        """Store the clients and job name; the run starts in the ``running`` state.

        Arguments:
            training_job_name: Name of the SageMaker training job.
            client: Client used for ``describe_training_job`` and
                ``stop_training_job``.
            log_client: Optional client used for ``describe_log_streams`` and
                ``get_log_events``. When ``None``, ``get_logs`` returns ``None``.
        """
        super().__init__()
        self.client = client
        self.log_client = log_client
        self.training_job_name = training_job_name
        self._status = Status("running")

    @property
    def id(self) -> str:
        """Return the run identifier, ``sagemaker-<training job name>``."""
        return f"sagemaker-{self.training_job_name}"

    async def get_logs(self) -> str | None:
        """Fetch log events of the training job's first log stream.

        Returns:
            The events joined by newlines as ``<timestamp>:<message>``, or
            ``None`` when no log client is configured, no log stream matches
            the job name, or any error occurs (a warning is printed instead of
            raising).
        """
        if self.log_client is None:
            return None
        try:
            describe_log_streams = event_loop_thread_exec(
                self.log_client.describe_log_streams
            )
            describe_res = await describe_log_streams(
                logGroupName="/aws/sagemaker/TrainingJobs",
                logStreamNamePrefix=self.training_job_name,
            )
            if not describe_res["logStreams"]:
                wandb.termwarn(
                    f"Failed to get logs for training job: {self.training_job_name}"
                )
                return None
            # Only the first matching stream is read.
            log_name = describe_res["logStreams"][0]["logStreamName"]
            get_log_events = event_loop_thread_exec(self.log_client.get_log_events)
            res = await get_log_events(
                logGroupName="/aws/sagemaker/TrainingJobs",
                logStreamName=log_name,
            )
            # A response without "events" raises KeyError here, which is
            # reported by the generic handler below (no ``assert``, which would
            # be stripped under ``python -O``).
            return "\n".join(
                f"{event['timestamp']}:{event['message']}" for event in res["events"]
            )
        except self.log_client.exceptions.ResourceNotFoundException:
            wandb.termwarn(
                f"Failed to get logs for training job: {self.training_job_name}"
            )
            return None
        except Exception as e:
            # Log retrieval is best-effort: warn and carry on.
            wandb.termwarn(
                f"Failed to handle logs for training job: {self.training_job_name} with error {e}"
            )
            return None

    async def wait(self) -> bool:
        """Poll the job status every 5 seconds until it reaches a terminal state.

        Returns:
            ``True`` if the final state is ``finished``, ``False`` otherwise.
        """
        while True:
            status_state = (await self.get_status()).state
            wandb.termlog(
                f"{LOG_PREFIX}Training job {self.training_job_name} status: {status_state}"
            )
            if status_state in ("stopped", "failed", "finished"):
                break
            await asyncio.sleep(5)
        return status_state == "finished"

    async def cancel(self) -> None:
        """Stop the training job if it is still running and wait for it to end."""
        status = await self.get_status()
        if status.state == "running":
            # Run the blocking boto3 call off the event loop, like the other
            # client calls in this class.
            stop_training_job = event_loop_thread_exec(self.client.stop_training_job)
            await stop_training_job(TrainingJobName=self.training_job_name)
            await self.wait()

    async def get_status(self) -> Status:
        """Query SageMaker for the job status and map it to a ``Status``.

        ``Completed`` and ``Stopped`` both map to ``finished``; ``Failed`` to
        ``failed``; ``Stopping`` to ``stopping``; ``InProgress`` to ``running``.
        Any other value leaves the previously stored status unchanged.
        """
        describe_training_job = event_loop_thread_exec(
            self.client.describe_training_job
        )
        job_status = (
            await describe_training_job(TrainingJobName=self.training_job_name)
        )["TrainingJobStatus"]
        if job_status == "Completed" or job_status == "Stopped":
            self._status = Status("finished")
        elif job_status == "Failed":
            self._status = Status("failed")
        elif job_status == "Stopping":
            self._status = Status("stopping")
        elif job_status == "InProgress":
            self._status = Status("running")
        return self._status


class SageMakerRunner(AbstractRunner):
    """Runner class, uses a project to create a SagemakerSubmittedRun."""

    def __init__(
        self,
        api: Api,
        backend_config: dict[str, Any],
        environment: AwsEnvironment,
        registry: AbstractRegistry,
    ) -> None:
        """Initialize the SagemakerRunner.

        Arguments:
            api (Api): The API instance.
            backend_config (Dict[str, Any]): The backend configuration.
            environment (AwsEnvironment): The AWS environment.
            registry (AbstractRegistry): The registry, stored on the runner.

        Raises:
            LaunchError: If the runner cannot be initialized (nothing in this
                method raises it directly; presumably from the base class).
        """
        super().__init__(api, backend_config)
        self.environment = environment
        self.registry = registry

    async def run(
        self,
        launch_project: LaunchProject,
        image_uri: str,
    ) -> AbstractRun | None:
        """Run a project on Amazon Sagemaker.

        If the resource args already name a ``TrainingImage`` under
        ``AlgorithmSpecification``, that image is used; otherwise ``image_uri``
        is used together with the project's entry point.

        Arguments:
            launch_project (LaunchProject): The project to run.
            image_uri (str): Image to train with when the user did not supply
                a ``TrainingImage``.

        Returns:
            Optional[AbstractRun]: The run instance.

        Raises:
            LaunchError: If the launch is unsuccessful.
        """
        _logger.info("using AWSSagemakerRunner")

        given_sagemaker_args = launch_project.resource_args.get("sagemaker")
        if given_sagemaker_args is None:
            raise LaunchError(
                "No sagemaker args specified. Specify sagemaker args in resource_args"
            )

        default_output_path = self.backend_config.get("runner", {}).get(
            "s3_output_path"
        )
        if default_output_path is not None and not default_output_path.startswith(
            "s3://"
        ):
            default_output_path = f"s3://{default_output_path}"

        session = await self.environment.get_session()
        # All boto3 calls are blocking, so they are pushed through
        # event_loop_thread_exec to keep the event loop responsive.
        sts_client = await event_loop_thread_exec(session.client)("sts")
        caller_id = await event_loop_thread_exec(sts_client.get_caller_identity)()
        account_id = caller_id["Account"]
        _logger.info(f"Using account ID {account_id}")
        partition = await self.environment.get_partition()
        role_arn = get_role_arn(
            given_sagemaker_args, self.backend_config, account_id, partition
        )

        # Create a sagemaker client to launch the job.
        sagemaker_client = await event_loop_thread_exec(session.client)("sagemaker")
        log_client = None
        try:
            log_client = await event_loop_thread_exec(session.client)("logs")
        except Exception as e:
            # Logs are optional: the job can still be launched without them.
            wandb.termwarn(
                f"Failed to connect to cloudwatch logs with error {e}, logs will not be available"
            )

        # if the user provided the image they want to use, use that, but warn it won't have swappable artifacts
        user_image = (given_sagemaker_args.get("AlgorithmSpecification") or {}).get(
            "TrainingImage"
        )
        if user_image is not None:
            sagemaker_args = build_sagemaker_args(
                launch_project,
                self._api,
                role_arn,
                launch_project.override_entrypoint,
                launch_project.override_args,
                MAX_ENV_LENGTHS[self.__class__.__name__],
                user_image,
                default_output_path,
            )
            _logger.info(
                f"Launching sagemaker job on user supplied image with args: {sagemaker_args}"
            )
            run = await launch_sagemaker_job(
                launch_project, sagemaker_args, sagemaker_client, log_client
            )
            if self.backend_config[PROJECT_SYNCHRONOUS]:
                await run.wait()
            return run

        _logger.info("Connecting to sagemaker client")
        entry_point = (
            launch_project.override_entrypoint or launch_project.get_job_entry_point()
        )
        command_args = []
        if entry_point is not None:
            command_args += entry_point.command
        command_args += launch_project.override_args
        if command_args:
            command_str = " ".join(command_args)
            wandb.termlog(
                f"{LOG_PREFIX}Launching run on sagemaker with entrypoint: {command_str}"
            )
        else:
            wandb.termlog(
                f"{LOG_PREFIX}Launching run on sagemaker with user-provided entrypoint in image"
            )

        sagemaker_args = build_sagemaker_args(
            launch_project,
            self._api,
            role_arn,
            entry_point,
            launch_project.override_args,
            MAX_ENV_LENGTHS[self.__class__.__name__],
            image_uri,
            default_output_path,
        )
        _logger.info(f"Launching sagemaker job with args: {sagemaker_args}")
        run = await launch_sagemaker_job(
            launch_project, sagemaker_args, sagemaker_client, log_client
        )
        if self.backend_config[PROJECT_SYNCHRONOUS]:
            await run.wait()
        return run


def merge_image_uri_with_algorithm_specification(
    algorithm_specification: dict[str, Any] | None,
    image_uri: str | None,
    entrypoint_command: list[str],
    args: list[str] | None,
) -> dict[str, Any]:
    """Create an AWS AlgorithmSpecification.

    AWS Sagemaker algorithms require a training image and an input mode. If the
    user does not specify the specification themselves, define the spec
    minimally using these two fields. Otherwise the user's specification is
    copied and its ``TrainingImage`` is replaced by ``image_uri`` whenever
    ``image_uri`` is non-empty.

    Arguments:
        algorithm_specification: User-provided specification, or ``None``.
            It is not modified; a shallow copy is returned.
        image_uri: Training image to use.
        entrypoint_command: Stored as ``ContainerEntrypoint`` when non-empty.
        args: Stored as ``ContainerArguments`` when non-empty.

    Returns:
        The merged specification.

    Raises:
        LaunchError: If no training image could be determined.
    """
    if algorithm_specification is None:
        merged: dict[str, Any] = {
            "TrainingImage": image_uri,
            "TrainingInputMode": "File",
        }
    else:
        # Work on a copy so the caller's resource args are left untouched.
        merged = dict(algorithm_specification)
        if image_uri:
            merged["TrainingImage"] = image_uri
    if entrypoint_command:
        merged["ContainerEntrypoint"] = entrypoint_command
    if args:
        merged["ContainerArguments"] = args

    # ``.get`` so that a user spec lacking the key yields the LaunchError
    # below rather than a bare KeyError.
    if merged.get("TrainingImage") is None:
        raise LaunchError("Failed determine tag for training image")
    return merged


def build_sagemaker_args(
    launch_project: LaunchProject,
    api: Api,
    role_arn: str,
    entry_point: EntryPoint | None,
    args: list[str] | None,
    max_env_length: int,
    image_uri: str,
    default_output_path: str | None = None,
) -> dict[str, Any]:
    """Assemble the keyword arguments for ``create_training_job``.

    The user's ``sagemaker`` resource args (after macro substitution) are
    camel-cased and used as a base; the values computed here
    (``OutputDataConfig``, ``TrainingJobName``, ``AlgorithmSpecification``,
    ``RoleArn``, ``Environment``, ``Tags``) take precedence over them.

    Arguments:
        launch_project: Project supplying resource args, run id and env vars.
        api: API instance forwarded to ``get_env_vars_dict``.
        role_arn: Execution role for the training job.
        entry_point: Entry point whose command becomes the container entrypoint.
        args: Container arguments.
        max_env_length: Forwarded to ``get_env_vars_dict``.
        image_uri: Training image; also passed to ``fill_macros``.
        default_output_path: S3 output path used when the user gave no
            ``OutputDataConfig``.

    Returns:
        The arguments with all ``None`` values removed.

    Raises:
        LaunchError: If the sagemaker args, ``OutputDataConfig``,
            ``ResourceConfig`` or ``StoppingCondition`` are missing.
    """
    sagemaker_args: dict[str, Any] = {}
    resource_args = launch_project.fill_macros(image_uri)
    given_sagemaker_args: dict[str, Any] | None = resource_args.get("sagemaker")

    if given_sagemaker_args is None:
        raise LaunchError(
            "No sagemaker args specified. Specify sagemaker args in resource_args"
        )
    if (
        given_sagemaker_args.get("OutputDataConfig") is None
        and default_output_path is not None
    ):
        sagemaker_args["OutputDataConfig"] = {"S3OutputPath": default_output_path}
    else:
        sagemaker_args["OutputDataConfig"] = given_sagemaker_args.get(
            "OutputDataConfig"
        )

    if sagemaker_args.get("OutputDataConfig") is None:
        raise LaunchError(
            "Sagemaker launcher requires an OutputDataConfig Sagemaker resource argument"
        )
    training_job_name = cast(
        str, (given_sagemaker_args.get("TrainingJobName") or launch_project.run_id)
    )
    sagemaker_args["TrainingJobName"] = training_job_name
    entry_cmd = entry_point.command if entry_point else []

    sagemaker_args["AlgorithmSpecification"] = (
        merge_image_uri_with_algorithm_specification(
            given_sagemaker_args.get(
                "AlgorithmSpecification",
                given_sagemaker_args.get("algorithm_specification"),
            ),
            image_uri,
            entry_cmd,
            args,
        )
    )

    sagemaker_args["RoleArn"] = role_arn

    camel_case_args = {
        to_camel_case(key): item for key, item in given_sagemaker_args.items()
    }
    # Computed values win over whatever the user passed under the same key.
    sagemaker_args = {
        **camel_case_args,
        **sagemaker_args,
    }

    if sagemaker_args.get("ResourceConfig") is None:
        raise LaunchError(
            "Sagemaker launcher requires a ResourceConfig resource argument"
        )

    if sagemaker_args.get("StoppingCondition") is None:
        raise LaunchError(
            "Sagemaker launcher requires a StoppingCondition resource argument"
        )

    # ``or {}`` tolerates an explicit ``Environment: null`` in the resource args.
    given_env = (
        given_sagemaker_args.get("Environment", sagemaker_args.get("environment", {}))
        or {}
    )
    calced_env = launch_project.get_env_vars_dict(api, max_env_length)
    # User-provided variables override the computed ones.
    total_env = {**calced_env, **given_env}
    sagemaker_args["Environment"] = total_env

    # Add wandb tag; build a new list so the user's ``Tags`` list is not
    # mutated, and tolerate an explicit ``Tags: null``.
    sagemaker_args["Tags"] = [
        *(sagemaker_args.get("Tags") or []),
        {"Key": "WandbRunId", "Value": launch_project.run_id},
    ]

    # remove args that were passed in for launch but not passed to sagemaker
    sagemaker_args.pop("EcrRepoName", None)
    sagemaker_args.pop("region", None)
    sagemaker_args.pop("profile", None)
    # A lowercase ``environment`` key has already been folded into
    # ``Environment`` above and must not be forwarded as its own argument.
    sagemaker_args.pop("environment", None)

    # clear the args that are None so they are not passed
    filtered_args = {k: v for k, v in sagemaker_args.items() if v is not None}

    return filtered_args


async def launch_sagemaker_job(
    launch_project: LaunchProject,
    sagemaker_args: dict[str, Any],
    sagemaker_client: boto3.Client,
    log_client: boto3.Client | None = None,
) -> SagemakerSubmittedRun:
    """Submit the training job and return a handle on it.

    Arguments:
        launch_project: Project whose ``run_id`` is the fallback job name.
        sagemaker_args: Keyword arguments passed to ``create_training_job``.
        sagemaker_client: Client used to create (and later track) the job.
        log_client: Optional client handed to the returned run for log access.

    Returns:
        The submitted run.

    Raises:
        LaunchError: If the response carries no ``TrainingJobArn``.
    """
    training_job_name = sagemaker_args.get("TrainingJobName") or launch_project.run_id
    create_training_job = event_loop_thread_exec(sagemaker_client.create_training_job)
    resp = await create_training_job(**sagemaker_args)

    if resp.get("TrainingJobArn") is None:
        raise LaunchError("Failed to create training job when submitting to SageMaker")

    run = SagemakerSubmittedRun(training_job_name, sagemaker_client, log_client)
    wandb.termlog(
        f"{LOG_PREFIX}Run job submitted with arn: {resp.get('TrainingJobArn')}"
    )
    url = f"https://{sagemaker_client.meta.region_name}.console.aws.amazon.com/sagemaker/home?region={sagemaker_client.meta.region_name}#/jobs/{training_job_name}"
    wandb.termlog(f"{LOG_PREFIX}See training job status at: {url}")
    return run


def get_role_arn(
    sagemaker_args: dict[str, Any],
    backend_config: dict[str, Any],
    account_id: str,
    partition: str,
) -> str:
    """Get the role arn from the sagemaker args or the backend config.

    Lookup order: ``RoleArn`` then ``role_arn`` in ``sagemaker_args``, then
    ``runner.role_arn`` in ``backend_config``. A value that is not already a
    full IAM ARN for ``partition`` is treated as a role name and expanded with
    ``account_id``.

    Raises:
        LaunchError: If no string role could be found.
    """
    role_arn = sagemaker_args.get("RoleArn") or sagemaker_args.get("role_arn")
    if role_arn is None:
        role_arn = backend_config.get("runner", {}).get("role_arn")
    # ``isinstance`` also rejects ``None``.
    if not isinstance(role_arn, str):
        raise LaunchError(
            "AWS sagemaker require a string RoleArn set this by adding a `RoleArn` key to the sagemaker "
            "field of resource_args"
        )
    if role_arn.startswith(f"arn:{partition}:iam::"):
        return role_arn  # type: ignore

    return f"arn:{partition}:iam::{account_id}:role/{role_arn}"