sagemaker_runner.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. """Implementation of the SageMakerRunner class."""
  2. from __future__ import annotations
  3. import asyncio
  4. import logging
  5. from typing import Any, cast
  6. if False:
  7. import boto3 # type: ignore
  8. import wandb
  9. from wandb.apis.internal import Api
  10. from wandb.sdk.launch.environment.aws_environment import AwsEnvironment
  11. from wandb.sdk.launch.errors import LaunchError
  12. from .._project_spec import EntryPoint, LaunchProject
  13. from ..registry.abstract import AbstractRegistry
  14. from ..utils import (
  15. LOG_PREFIX,
  16. MAX_ENV_LENGTHS,
  17. PROJECT_SYNCHRONOUS,
  18. event_loop_thread_exec,
  19. to_camel_case,
  20. )
  21. from .abstract import AbstractRun, AbstractRunner, Status
  22. _logger = logging.getLogger(__name__)
  23. class SagemakerSubmittedRun(AbstractRun):
  24. """Instance of ``AbstractRun`` corresponding to a subprocess launched to run an entry point command on aws sagemaker."""
  25. def __init__(
  26. self,
  27. training_job_name: str,
  28. client: boto3.Client,
  29. log_client: boto3.Client | None = None,
  30. ) -> None:
  31. super().__init__()
  32. self.client = client
  33. self.log_client = log_client
  34. self.training_job_name = training_job_name
  35. self._status = Status("running")
  36. @property
  37. def id(self) -> str:
  38. return f"sagemaker-{self.training_job_name}"
  39. async def get_logs(self) -> str | None:
  40. if self.log_client is None:
  41. return None
  42. try:
  43. describe_log_streams = event_loop_thread_exec(
  44. self.log_client.describe_log_streams
  45. )
  46. describe_res = await describe_log_streams(
  47. logGroupName="/aws/sagemaker/TrainingJobs",
  48. logStreamNamePrefix=self.training_job_name,
  49. )
  50. if len(describe_res["logStreams"]) == 0:
  51. wandb.termwarn(
  52. f"Failed to get logs for training job: {self.training_job_name}"
  53. )
  54. return None
  55. log_name = describe_res["logStreams"][0]["logStreamName"]
  56. get_log_events = event_loop_thread_exec(self.log_client.get_log_events)
  57. res = await get_log_events(
  58. logGroupName="/aws/sagemaker/TrainingJobs",
  59. logStreamName=log_name,
  60. )
  61. assert "events" in res
  62. return "\n".join(
  63. [f"{event['timestamp']}:{event['message']}" for event in res["events"]]
  64. )
  65. except self.log_client.exceptions.ResourceNotFoundException:
  66. wandb.termwarn(
  67. f"Failed to get logs for training job: {self.training_job_name}"
  68. )
  69. return None
  70. except Exception as e:
  71. wandb.termwarn(
  72. f"Failed to handle logs for training job: {self.training_job_name} with error {str(e)}"
  73. )
  74. return None
  75. async def wait(self) -> bool:
  76. while True:
  77. status_state = (await self.get_status()).state
  78. wandb.termlog(
  79. f"{LOG_PREFIX}Training job {self.training_job_name} status: {status_state}"
  80. )
  81. if status_state in ["stopped", "failed", "finished"]:
  82. break
  83. await asyncio.sleep(5)
  84. return status_state == "finished"
  85. async def cancel(self) -> None:
  86. # Interrupt child process if it hasn't already exited
  87. status = await self.get_status()
  88. if status.state == "running":
  89. self.client.stop_training_job(TrainingJobName=self.training_job_name)
  90. await self.wait()
  91. async def get_status(self) -> Status:
  92. describe_training_job = event_loop_thread_exec(
  93. self.client.describe_training_job
  94. )
  95. job_status = (
  96. await describe_training_job(TrainingJobName=self.training_job_name)
  97. )["TrainingJobStatus"]
  98. if job_status == "Completed" or job_status == "Stopped":
  99. self._status = Status("finished")
  100. elif job_status == "Failed":
  101. self._status = Status("failed")
  102. elif job_status == "Stopping":
  103. self._status = Status("stopping")
  104. elif job_status == "InProgress":
  105. self._status = Status("running")
  106. return self._status
  107. class SageMakerRunner(AbstractRunner):
  108. """Runner class, uses a project to create a SagemakerSubmittedRun."""
  109. def __init__(
  110. self,
  111. api: Api,
  112. backend_config: dict[str, Any],
  113. environment: AwsEnvironment,
  114. registry: AbstractRegistry,
  115. ) -> None:
  116. """Initialize the SagemakerRunner.
  117. Arguments:
  118. api (Api): The API instance.
  119. backend_config (Dict[str, Any]): The backend configuration.
  120. environment (AwsEnvironment): The AWS environment.
  121. Raises:
  122. LaunchError: If the runner cannot be initialized.
  123. """
  124. super().__init__(api, backend_config)
  125. self.environment = environment
  126. self.registry = registry
  127. async def run(
  128. self,
  129. launch_project: LaunchProject,
  130. image_uri: str,
  131. ) -> AbstractRun | None:
  132. """Run a project on Amazon Sagemaker.
  133. Arguments:
  134. launch_project (LaunchProject): The project to run.
  135. Returns:
  136. Optional[AbstractRun]: The run instance.
  137. Raises:
  138. LaunchError: If the launch is unsuccessful.
  139. """
  140. _logger.info("using AWSSagemakerRunner")
  141. given_sagemaker_args = launch_project.resource_args.get("sagemaker")
  142. if given_sagemaker_args is None:
  143. raise LaunchError(
  144. "No sagemaker args specified. Specify sagemaker args in resource_args"
  145. )
  146. default_output_path = self.backend_config.get("runner", {}).get(
  147. "s3_output_path"
  148. )
  149. if default_output_path is not None and not default_output_path.startswith(
  150. "s3://"
  151. ):
  152. default_output_path = f"s3://{default_output_path}"
  153. session = await self.environment.get_session()
  154. client = await event_loop_thread_exec(session.client)("sts")
  155. caller_id = client.get_caller_identity()
  156. account_id = caller_id["Account"]
  157. _logger.info(f"Using account ID {account_id}")
  158. partition = await self.environment.get_partition()
  159. role_arn = get_role_arn(
  160. given_sagemaker_args, self.backend_config, account_id, partition
  161. )
  162. # Create a sagemaker client to launch the job.
  163. sagemaker_client = session.client("sagemaker")
  164. log_client = None
  165. try:
  166. log_client = session.client("logs")
  167. except Exception as e:
  168. wandb.termwarn(
  169. f"Failed to connect to cloudwatch logs with error {str(e)}, logs will not be available"
  170. )
  171. # if the user provided the image they want to use, use that, but warn it won't have swappable artifacts
  172. if (
  173. given_sagemaker_args.get("AlgorithmSpecification", {}).get("TrainingImage")
  174. is not None
  175. ):
  176. sagemaker_args = build_sagemaker_args(
  177. launch_project,
  178. self._api,
  179. role_arn,
  180. launch_project.override_entrypoint,
  181. launch_project.override_args,
  182. MAX_ENV_LENGTHS[self.__class__.__name__],
  183. given_sagemaker_args.get("AlgorithmSpecification", {}).get(
  184. "TrainingImage"
  185. ),
  186. default_output_path,
  187. )
  188. _logger.info(
  189. f"Launching sagemaker job on user supplied image with args: {sagemaker_args}"
  190. )
  191. run = await launch_sagemaker_job(
  192. launch_project, sagemaker_args, sagemaker_client, log_client
  193. )
  194. if self.backend_config[PROJECT_SYNCHRONOUS]:
  195. await run.wait()
  196. return run
  197. _logger.info("Connecting to sagemaker client")
  198. entry_point = (
  199. launch_project.override_entrypoint or launch_project.get_job_entry_point()
  200. )
  201. command_args = []
  202. if entry_point is not None:
  203. command_args += entry_point.command
  204. command_args += launch_project.override_args
  205. if command_args:
  206. command_str = " ".join(command_args)
  207. wandb.termlog(
  208. f"{LOG_PREFIX}Launching run on sagemaker with entrypoint: {command_str}"
  209. )
  210. else:
  211. wandb.termlog(
  212. f"{LOG_PREFIX}Launching run on sagemaker with user-provided entrypoint in image"
  213. )
  214. sagemaker_args = build_sagemaker_args(
  215. launch_project,
  216. self._api,
  217. role_arn,
  218. entry_point,
  219. launch_project.override_args,
  220. MAX_ENV_LENGTHS[self.__class__.__name__],
  221. image_uri,
  222. default_output_path,
  223. )
  224. _logger.info(f"Launching sagemaker job with args: {sagemaker_args}")
  225. run = await launch_sagemaker_job(
  226. launch_project, sagemaker_args, sagemaker_client, log_client
  227. )
  228. if self.backend_config[PROJECT_SYNCHRONOUS]:
  229. await run.wait()
  230. return run
  231. def merge_image_uri_with_algorithm_specification(
  232. algorithm_specification: dict[str, Any] | None,
  233. image_uri: str | None,
  234. entrypoint_command: list[str],
  235. args: list[str] | None,
  236. ) -> dict[str, Any]:
  237. """Create an AWS AlgorithmSpecification.
  238. AWS Sagemaker algorithms require a training image and an input mode. If the user
  239. does not specify the specification themselves, define the spec minimally using these
  240. two fields. Otherwise, if they specify the AlgorithmSpecification set the training
  241. image if it is not set.
  242. """
  243. if algorithm_specification is None:
  244. algorithm_specification = {
  245. "TrainingImage": image_uri,
  246. "TrainingInputMode": "File",
  247. }
  248. else:
  249. if image_uri:
  250. algorithm_specification["TrainingImage"] = image_uri
  251. if entrypoint_command:
  252. algorithm_specification["ContainerEntrypoint"] = entrypoint_command
  253. if args:
  254. algorithm_specification["ContainerArguments"] = args
  255. if algorithm_specification["TrainingImage"] is None:
  256. raise LaunchError("Failed determine tag for training image")
  257. return algorithm_specification
  258. def build_sagemaker_args(
  259. launch_project: LaunchProject,
  260. api: Api,
  261. role_arn: str,
  262. entry_point: EntryPoint | None,
  263. args: list[str] | None,
  264. max_env_length: int,
  265. image_uri: str,
  266. default_output_path: str | None = None,
  267. ) -> dict[str, Any]:
  268. sagemaker_args: dict[str, Any] = {}
  269. resource_args = launch_project.fill_macros(image_uri)
  270. given_sagemaker_args: dict[str, Any] | None = resource_args.get("sagemaker")
  271. if given_sagemaker_args is None:
  272. raise LaunchError(
  273. "No sagemaker args specified. Specify sagemaker args in resource_args"
  274. )
  275. if (
  276. given_sagemaker_args.get("OutputDataConfig") is None
  277. and default_output_path is not None
  278. ):
  279. sagemaker_args["OutputDataConfig"] = {"S3OutputPath": default_output_path}
  280. else:
  281. sagemaker_args["OutputDataConfig"] = given_sagemaker_args.get(
  282. "OutputDataConfig"
  283. )
  284. if sagemaker_args.get("OutputDataConfig") is None:
  285. raise LaunchError(
  286. "Sagemaker launcher requires an OutputDataConfig Sagemaker resource argument"
  287. )
  288. training_job_name = cast(
  289. str, (given_sagemaker_args.get("TrainingJobName") or launch_project.run_id)
  290. )
  291. sagemaker_args["TrainingJobName"] = training_job_name
  292. entry_cmd = entry_point.command if entry_point else []
  293. sagemaker_args["AlgorithmSpecification"] = (
  294. merge_image_uri_with_algorithm_specification(
  295. given_sagemaker_args.get(
  296. "AlgorithmSpecification",
  297. given_sagemaker_args.get("algorithm_specification"),
  298. ),
  299. image_uri,
  300. entry_cmd,
  301. args,
  302. )
  303. )
  304. sagemaker_args["RoleArn"] = role_arn
  305. camel_case_args = {
  306. to_camel_case(key): item for key, item in given_sagemaker_args.items()
  307. }
  308. sagemaker_args = {
  309. **camel_case_args,
  310. **sagemaker_args,
  311. }
  312. if sagemaker_args.get("ResourceConfig") is None:
  313. raise LaunchError(
  314. "Sagemaker launcher requires a ResourceConfig resource argument"
  315. )
  316. if sagemaker_args.get("StoppingCondition") is None:
  317. raise LaunchError(
  318. "Sagemaker launcher requires a StoppingCondition resource argument"
  319. )
  320. given_env = given_sagemaker_args.get(
  321. "Environment", sagemaker_args.get("environment", {})
  322. )
  323. calced_env = launch_project.get_env_vars_dict(api, max_env_length)
  324. total_env = {**calced_env, **given_env}
  325. sagemaker_args["Environment"] = total_env
  326. # Add wandb tag
  327. tags = sagemaker_args.get("Tags", [])
  328. tags.append({"Key": "WandbRunId", "Value": launch_project.run_id})
  329. sagemaker_args["Tags"] = tags
  330. # remove args that were passed in for launch but not passed to sagemaker
  331. sagemaker_args.pop("EcrRepoName", None)
  332. sagemaker_args.pop("region", None)
  333. sagemaker_args.pop("profile", None)
  334. # clear the args that are None so they are not passed
  335. filtered_args = {k: v for k, v in sagemaker_args.items() if v is not None}
  336. return filtered_args
  337. async def launch_sagemaker_job(
  338. launch_project: LaunchProject,
  339. sagemaker_args: dict[str, Any],
  340. sagemaker_client: boto3.Client,
  341. log_client: boto3.Client | None = None,
  342. ) -> SagemakerSubmittedRun:
  343. training_job_name = sagemaker_args.get("TrainingJobName") or launch_project.run_id
  344. create_training_job = event_loop_thread_exec(sagemaker_client.create_training_job)
  345. resp = await create_training_job(**sagemaker_args)
  346. if resp.get("TrainingJobArn") is None:
  347. raise LaunchError("Failed to create training job when submitting to SageMaker")
  348. run = SagemakerSubmittedRun(training_job_name, sagemaker_client, log_client)
  349. wandb.termlog(
  350. f"{LOG_PREFIX}Run job submitted with arn: {resp.get('TrainingJobArn')}"
  351. )
  352. url = f"https://{sagemaker_client.meta.region_name}.console.aws.amazon.com/sagemaker/home?region={sagemaker_client.meta.region_name}#/jobs/{training_job_name}"
  353. wandb.termlog(f"{LOG_PREFIX}See training job status at: {url}")
  354. return run
  355. def get_role_arn(
  356. sagemaker_args: dict[str, Any],
  357. backend_config: dict[str, Any],
  358. account_id: str,
  359. partition: str,
  360. ) -> str:
  361. """Get the role arn from the sagemaker args or the backend config."""
  362. role_arn = sagemaker_args.get("RoleArn") or sagemaker_args.get("role_arn")
  363. if role_arn is None:
  364. role_arn = backend_config.get("runner", {}).get("role_arn")
  365. if role_arn is None or not isinstance(role_arn, str):
  366. raise LaunchError(
  367. "AWS sagemaker require a string RoleArn set this by adding a `RoleArn` key to the sagemaker"
  368. "field of resource_args"
  369. )
  370. if role_arn.startswith(f"arn:{partition}:iam::"):
  371. return role_arn # type: ignore
  372. return f"arn:{partition}:iam::{account_id}:role/{role_arn}"