| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426 |
- """Implementation of the SageMakerRunner class."""
- from __future__ import annotations
- import asyncio
- import logging
- from typing import Any, cast
- if False:
- import boto3 # type: ignore
- import wandb
- from wandb.apis.internal import Api
- from wandb.sdk.launch.environment.aws_environment import AwsEnvironment
- from wandb.sdk.launch.errors import LaunchError
- from .._project_spec import EntryPoint, LaunchProject
- from ..registry.abstract import AbstractRegistry
- from ..utils import (
- LOG_PREFIX,
- MAX_ENV_LENGTHS,
- PROJECT_SYNCHRONOUS,
- event_loop_thread_exec,
- to_camel_case,
- )
- from .abstract import AbstractRun, AbstractRunner, Status
- _logger = logging.getLogger(__name__)
class SagemakerSubmittedRun(AbstractRun):
    """Instance of ``AbstractRun`` corresponding to a training job submitted to AWS SageMaker."""

    # CloudWatch log group queried by ``get_logs``.
    _LOG_GROUP_NAME = "/aws/sagemaker/TrainingJobs"

    # Maps the ``TrainingJobStatus`` reported by SageMaker to the state stored
    # in ``Status``. Statuses that are not listed leave the last known status
    # untouched.
    _STATE_BY_JOB_STATUS = {
        "Completed": "finished",
        "Stopped": "finished",
        "Failed": "failed",
        "Stopping": "stopping",
        "InProgress": "running",
    }

    def __init__(
        self,
        training_job_name: str,
        client: boto3.Client,
        log_client: boto3.Client | None = None,
    ) -> None:
        """Store the clients and job name; the run starts in the "running" state.

        Arguments:
            training_job_name (str): Name of the SageMaker training job.
            client (boto3.Client): Client used to describe and stop the job.
            log_client (Optional[boto3.Client]): Client used to read the job
                logs. When ``None``, ``get_logs`` always returns ``None``.
        """
        super().__init__()
        self.client = client
        self.log_client = log_client
        self.training_job_name = training_job_name
        self._status = Status("running")

    @property
    def id(self) -> str:
        """Identifier of the run: the training job name prefixed with ``sagemaker-``."""
        return f"sagemaker-{self.training_job_name}"

    async def get_logs(self) -> str | None:
        """Fetch the job's log events as newline-joined ``timestamp:message`` lines.

        Only the first log stream whose name starts with the training job name
        is read. Failures are reported with ``wandb.termwarn`` and never raised.

        Returns:
            Optional[str]: The log text, or ``None`` when no log client is
            configured, no stream exists, or the logs could not be retrieved.
        """
        if self.log_client is None:
            return None
        try:
            describe_log_streams = event_loop_thread_exec(
                self.log_client.describe_log_streams
            )
            streams_response = await describe_log_streams(
                logGroupName=self._LOG_GROUP_NAME,
                logStreamNamePrefix=self.training_job_name,
            )
            log_streams = streams_response["logStreams"]
            if not log_streams:
                wandb.termwarn(
                    f"Failed to get logs for training job: {self.training_job_name}"
                )
                return None
            get_log_events = event_loop_thread_exec(self.log_client.get_log_events)
            events_response = await get_log_events(
                logGroupName=self._LOG_GROUP_NAME,
                logStreamName=log_streams[0]["logStreamName"],
            )
            # A response without "events" raises KeyError here, which the
            # generic handler below reports; no ``assert`` is needed (asserts
            # are stripped under ``python -O``).
            return "\n".join(
                f"{event['timestamp']}:{event['message']}"
                for event in events_response["events"]
            )
        except self.log_client.exceptions.ResourceNotFoundException:
            wandb.termwarn(
                f"Failed to get logs for training job: {self.training_job_name}"
            )
            return None
        except Exception as e:
            # Deliberately best-effort: log retrieval must not fail the run.
            wandb.termwarn(
                f"Failed to handle logs for training job: {self.training_job_name} with error {str(e)}"
            )
            return None

    async def wait(self) -> bool:
        """Poll the job every 5 seconds until it reaches a terminal state.

        Returns:
            bool: ``True`` if the final state is "finished", ``False`` otherwise.
        """
        while True:
            status_state = (await self.get_status()).state
            wandb.termlog(
                f"{LOG_PREFIX}Training job {self.training_job_name} status: {status_state}"
            )
            if status_state in ["stopped", "failed", "finished"]:
                break
            await asyncio.sleep(5)
        return status_state == "finished"

    async def cancel(self) -> None:
        """Stop the training job if it is still running and wait for it to end."""
        status = await self.get_status()
        if status.state == "running":
            # Run the blocking boto call in a worker thread, like every other
            # client call in this class, so the event loop is not stalled.
            stop_training_job = event_loop_thread_exec(self.client.stop_training_job)
            await stop_training_job(TrainingJobName=self.training_job_name)
            await self.wait()

    async def get_status(self) -> Status:
        """Query SageMaker for the job status and update the cached ``Status``.

        Returns:
            Status: The updated status, or the previously cached one when
            SageMaker reports a status this class does not map.
        """
        describe_training_job = event_loop_thread_exec(
            self.client.describe_training_job
        )
        description = await describe_training_job(
            TrainingJobName=self.training_job_name
        )
        state = self._STATE_BY_JOB_STATUS.get(description["TrainingJobStatus"])
        if state is not None:
            self._status = Status(state)
        return self._status
class SageMakerRunner(AbstractRunner):
    """Runner class, uses a project to create a SagemakerSubmittedRun."""

    def __init__(
        self,
        api: Api,
        backend_config: dict[str, Any],
        environment: AwsEnvironment,
        registry: AbstractRegistry,
    ) -> None:
        """Initialize the SagemakerRunner.

        Arguments:
            api (Api): The API instance.
            backend_config (Dict[str, Any]): The backend configuration.
            environment (AwsEnvironment): The AWS environment.
            registry (AbstractRegistry): The registry, stored on the runner.
        """
        super().__init__(api, backend_config)
        self.environment = environment
        self.registry = registry

    async def run(
        self,
        launch_project: LaunchProject,
        image_uri: str,
    ) -> AbstractRun | None:
        """Run a project on Amazon Sagemaker.

        Arguments:
            launch_project (LaunchProject): The project to run.
            image_uri (str): Image to train with. Ignored when the sagemaker
                resource args already name an
                ``AlgorithmSpecification.TrainingImage``.

        Returns:
            Optional[AbstractRun]: The run instance.

        Raises:
            LaunchError: If the launch is unsuccessful.
        """
        _logger.info("using AWSSagemakerRunner")

        given_sagemaker_args = launch_project.resource_args.get("sagemaker")
        if given_sagemaker_args is None:
            raise LaunchError(
                "No sagemaker args specified. Specify sagemaker args in resource_args"
            )

        default_output_path = self.backend_config.get("runner", {}).get(
            "s3_output_path"
        )
        if default_output_path is not None and not default_output_path.startswith(
            "s3://"
        ):
            default_output_path = f"s3://{default_output_path}"

        session = await self.environment.get_session()
        # All boto calls below block, so they are executed in a worker thread
        # to keep the event loop responsive.
        create_client = event_loop_thread_exec(session.client)
        sts_client = await create_client("sts")
        caller_id = await event_loop_thread_exec(sts_client.get_caller_identity)()
        account_id = caller_id["Account"]
        _logger.info(f"Using account ID {account_id}")
        partition = await self.environment.get_partition()
        role_arn = get_role_arn(
            given_sagemaker_args, self.backend_config, account_id, partition
        )

        # Create a sagemaker client to launch the job.
        sagemaker_client = await create_client("sagemaker")
        log_client = None
        try:
            log_client = await create_client("logs")
        except Exception as e:
            # Logs are optional; the job is still launched without them.
            wandb.termwarn(
                f"Failed to connect to cloudwatch logs with error {str(e)}, logs will not be available"
            )

        max_env_length = MAX_ENV_LENGTHS[self.__class__.__name__]

        # If the user provided the image they want to use, launch it as-is
        # with only the override entrypoint/args from the project.
        user_image = (given_sagemaker_args.get("AlgorithmSpecification") or {}).get(
            "TrainingImage"
        )
        if user_image is not None:
            sagemaker_args = build_sagemaker_args(
                launch_project,
                self._api,
                role_arn,
                launch_project.override_entrypoint,
                launch_project.override_args,
                max_env_length,
                user_image,
                default_output_path,
            )
            _logger.info(
                f"Launching sagemaker job on user supplied image with args: {sagemaker_args}"
            )
            return await self._submit(
                launch_project, sagemaker_args, sagemaker_client, log_client
            )

        _logger.info("Connecting to sagemaker client")
        entry_point = (
            launch_project.override_entrypoint or launch_project.get_job_entry_point()
        )
        command_args: list[str] = []
        if entry_point is not None:
            command_args += entry_point.command
        command_args += launch_project.override_args
        if command_args:
            command_str = " ".join(command_args)
            wandb.termlog(
                f"{LOG_PREFIX}Launching run on sagemaker with entrypoint: {command_str}"
            )
        else:
            wandb.termlog(
                f"{LOG_PREFIX}Launching run on sagemaker with user-provided entrypoint in image"
            )

        sagemaker_args = build_sagemaker_args(
            launch_project,
            self._api,
            role_arn,
            entry_point,
            launch_project.override_args,
            max_env_length,
            image_uri,
            default_output_path,
        )
        _logger.info(f"Launching sagemaker job with args: {sagemaker_args}")
        return await self._submit(
            launch_project, sagemaker_args, sagemaker_client, log_client
        )

    async def _submit(
        self,
        launch_project: LaunchProject,
        sagemaker_args: dict[str, Any],
        sagemaker_client: boto3.Client,
        log_client: boto3.Client | None,
    ) -> SagemakerSubmittedRun:
        """Create the training job and, for synchronous projects, wait for it."""
        run = await launch_sagemaker_job(
            launch_project, sagemaker_args, sagemaker_client, log_client
        )
        if self.backend_config[PROJECT_SYNCHRONOUS]:
            await run.wait()
        return run
def merge_image_uri_with_algorithm_specification(
    algorithm_specification: dict[str, Any] | None,
    image_uri: str | None,
    entrypoint_command: list[str],
    args: list[str] | None,
) -> dict[str, Any]:
    """Create an AWS AlgorithmSpecification.

    AWS Sagemaker algorithms require a training image and an input mode. If the user
    does not specify the specification themselves, define the spec minimally using these
    two fields. Otherwise, a non-empty ``image_uri`` replaces the ``TrainingImage`` of
    the given specification. The given dict is updated in place and returned.

    Arguments:
        algorithm_specification: User-supplied specification, or ``None``.
        image_uri: Training image to use; may be empty when the specification
            already names one.
        entrypoint_command: Stored as ``ContainerEntrypoint`` when non-empty.
        args: Stored as ``ContainerArguments`` when non-empty.

    Returns:
        The completed algorithm specification.

    Raises:
        LaunchError: If no training image could be determined.
    """
    if algorithm_specification is None:
        algorithm_specification = {
            "TrainingImage": image_uri,
            "TrainingInputMode": "File",
        }
    elif image_uri:
        algorithm_specification["TrainingImage"] = image_uri
    if entrypoint_command:
        algorithm_specification["ContainerEntrypoint"] = entrypoint_command
    if args:
        algorithm_specification["ContainerArguments"] = args
    # ``.get`` so that a user-supplied spec without a "TrainingImage" key is
    # reported as a LaunchError instead of leaking a bare KeyError.
    if algorithm_specification.get("TrainingImage") is None:
        raise LaunchError("Failed determine tag for training image")
    return algorithm_specification
def build_sagemaker_args(
    launch_project: LaunchProject,
    api: Api,
    role_arn: str,
    entry_point: EntryPoint | None,
    args: list[str] | None,
    max_env_length: int,
    image_uri: str,
    default_output_path: str | None = None,
) -> dict[str, Any]:
    """Assemble the keyword arguments for SageMaker's ``create_training_job``.

    The user's ``sagemaker`` resource args (after macro filling) are the
    base. Values computed here (output config, job name, algorithm
    specification, role ARN, environment and tags) take precedence over them.

    Arguments:
        launch_project: Project supplying resource args, run id and env vars.
        api: API instance forwarded to ``launch_project.get_env_vars_dict``.
        role_arn: Execution role ARN for the training job.
        entry_point: Entry point whose command becomes the container
            entrypoint, or ``None`` for no entrypoint override.
        args: Container arguments, if any.
        max_env_length: Forwarded to ``launch_project.get_env_vars_dict``.
        image_uri: Training image URI.
        default_output_path: S3 output path used when the resource args do
            not contain an ``OutputDataConfig``.

    Returns:
        The training job arguments with ``None`` values removed.

    Raises:
        LaunchError: If the sagemaker args, ``OutputDataConfig``,
            ``ResourceConfig`` or ``StoppingCondition`` are missing.
    """
    resource_args = launch_project.fill_macros(image_uri)
    given_sagemaker_args: dict[str, Any] | None = resource_args.get("sagemaker")
    if given_sagemaker_args is None:
        raise LaunchError(
            "No sagemaker args specified. Specify sagemaker args in resource_args"
        )

    # A user-supplied OutputDataConfig wins over the runner's default path.
    output_data_config = given_sagemaker_args.get("OutputDataConfig")
    if output_data_config is None and default_output_path is not None:
        output_data_config = {"S3OutputPath": default_output_path}
    if output_data_config is None:
        raise LaunchError(
            "Sagemaker launcher requires an OutputDataConfig Sagemaker resource argument"
        )

    computed_args: dict[str, Any] = {"OutputDataConfig": output_data_config}
    computed_args["TrainingJobName"] = cast(
        str, (given_sagemaker_args.get("TrainingJobName") or launch_project.run_id)
    )
    computed_args["AlgorithmSpecification"] = (
        merge_image_uri_with_algorithm_specification(
            given_sagemaker_args.get(
                "AlgorithmSpecification",
                given_sagemaker_args.get("algorithm_specification"),
            ),
            image_uri,
            entry_point.command if entry_point else [],
            args,
        )
    )
    computed_args["RoleArn"] = role_arn

    # User args are re-keyed with ``to_camel_case``; computed values override.
    sagemaker_args: dict[str, Any] = {
        **{to_camel_case(key): item for key, item in given_sagemaker_args.items()},
        **computed_args,
    }

    if sagemaker_args.get("ResourceConfig") is None:
        raise LaunchError(
            "Sagemaker launcher requires a ResourceConfig resource argument"
        )
    if sagemaker_args.get("StoppingCondition") is None:
        raise LaunchError(
            "Sagemaker launcher requires a StoppingCondition resource argument"
        )

    # User-provided environment entries override the computed ones. ``or {}``
    # tolerates an explicit null ``Environment`` in the resource args.
    given_env = (
        given_sagemaker_args.get("Environment", sagemaker_args.get("environment", {}))
        or {}
    )
    calced_env = launch_project.get_env_vars_dict(api, max_env_length)
    sagemaker_args["Environment"] = {**calced_env, **given_env}

    # Add wandb tag. Build a new list so the caller's ``Tags`` list is not
    # mutated (repeated builds would otherwise accumulate duplicate tags), and
    # tolerate an explicit null ``Tags``.
    tags = list(sagemaker_args.get("Tags") or [])
    tags.append({"Key": "WandbRunId", "Value": launch_project.run_id})
    sagemaker_args["Tags"] = tags

    # remove args that were passed in for launch but not passed to sagemaker
    for launch_only_key in ("EcrRepoName", "region", "profile"):
        sagemaker_args.pop(launch_only_key, None)

    # clear the args that are None so they are not passed
    return {k: v for k, v in sagemaker_args.items() if v is not None}
async def launch_sagemaker_job(
    launch_project: LaunchProject,
    sagemaker_args: dict[str, Any],
    sagemaker_client: boto3.Client,
    log_client: boto3.Client | None = None,
) -> SagemakerSubmittedRun:
    """Submit a training job to SageMaker and wrap it in a submitted run.

    Arguments:
        launch_project: Project whose ``run_id`` names the job when
            ``sagemaker_args`` has no ``TrainingJobName``.
        sagemaker_args: Keyword arguments passed to ``create_training_job``.
        sagemaker_client: Client used to create and later track the job.
        log_client: Optional client handed to the run for log retrieval.

    Returns:
        SagemakerSubmittedRun: Handle to the submitted training job.

    Raises:
        LaunchError: If the response carries no ``TrainingJobArn``.
    """
    job_name = sagemaker_args.get("TrainingJobName") or launch_project.run_id

    # The boto call blocks, so it is executed through the thread helper.
    submit = event_loop_thread_exec(sagemaker_client.create_training_job)
    response = await submit(**sagemaker_args)

    job_arn = response.get("TrainingJobArn")
    if job_arn is None:
        raise LaunchError("Failed to create training job when submitting to SageMaker")

    submitted_run = SagemakerSubmittedRun(job_name, sagemaker_client, log_client)
    wandb.termlog(f"{LOG_PREFIX}Run job submitted with arn: {job_arn}")

    region = sagemaker_client.meta.region_name
    console_url = f"https://{region}.console.aws.amazon.com/sagemaker/home?region={region}#/jobs/{job_name}"
    wandb.termlog(f"{LOG_PREFIX}See training job status at: {console_url}")
    return submitted_run
def get_role_arn(
    sagemaker_args: dict[str, Any],
    backend_config: dict[str, Any],
    account_id: str,
    partition: str,
) -> str:
    """Get the role arn from the sagemaker args or the backend config.

    The role is read from ``RoleArn`` or ``role_arn`` in ``sagemaker_args``,
    falling back to ``backend_config["runner"]["role_arn"]``. A value that is
    not already a full IAM ARN for ``partition`` is treated as a role name and
    expanded using ``account_id``.

    Arguments:
        sagemaker_args: The user's sagemaker resource args.
        backend_config: The backend configuration.
        account_id: AWS account id used when expanding a bare role name.
        partition: AWS partition used in the ARN prefix.

    Returns:
        The full role ARN.

    Raises:
        LaunchError: If no string role could be found.
    """
    role_arn = sagemaker_args.get("RoleArn") or sagemaker_args.get("role_arn")
    if role_arn is None:
        role_arn = backend_config.get("runner", {}).get("role_arn")
    # ``isinstance`` also rejects ``None``, so no separate None check is needed.
    if not isinstance(role_arn, str):
        raise LaunchError(
            "AWS sagemaker require a string RoleArn set this by adding a `RoleArn` key to the sagemaker "
            "field of resource_args"
        )
    if role_arn.startswith(f"arn:{partition}:iam::"):
        return role_arn
    return f"arn:{partition}:iam::{account_id}:role/{role_arn}"
|