vertex_runner.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. from __future__ import annotations
  2. import asyncio
  3. import logging
  4. from typing import Any
  5. if False:
  6. from google.cloud import aiplatform # type: ignore # noqa: F401
  7. from wandb.apis.internal import Api
  8. from wandb.util import get_module
  9. from .._project_spec import LaunchProject
  10. from ..environment.gcp_environment import GcpEnvironment
  11. from ..errors import LaunchError
  12. from ..registry.abstract import AbstractRegistry
  13. from ..utils import MAX_ENV_LENGTHS, PROJECT_SYNCHRONOUS, event_loop_thread_exec
  14. from .abstract import AbstractRun, AbstractRunner, Status
  15. GCP_CONSOLE_URI = "https://console.cloud.google.com"
  16. _logger = logging.getLogger(__name__)
  17. WANDB_RUN_ID_KEY = "wandb-run-id"
  18. class VertexSubmittedRun(AbstractRun):
  19. def __init__(self, job: Any) -> None:
  20. self._job = job
  21. @property
  22. def id(self) -> str:
  23. # numeric ID of the custom training job
  24. return self._job.name # type: ignore
  25. async def get_logs(self) -> str | None:
  26. # TODO: implement
  27. return None
  28. @property
  29. def name(self) -> str:
  30. return self._job.display_name # type: ignore
  31. @property
  32. def gcp_region(self) -> str:
  33. return self._job.location # type: ignore
  34. @property
  35. def gcp_project(self) -> str:
  36. return self._job.project # type: ignore
  37. def get_page_link(self) -> str:
  38. return f"{GCP_CONSOLE_URI}/vertex-ai/locations/{self.gcp_region}/training/{self.id}?project={self.gcp_project}"
  39. async def wait(self) -> bool:
  40. # TODO: run this in a separate thread.
  41. await self._job.wait()
  42. return (await self.get_status()).state == "finished"
  43. async def get_status(self) -> Status:
  44. job_state = str(self._job.state) # extract from type PipelineState
  45. if job_state == "JobState.JOB_STATE_SUCCEEDED":
  46. return Status("finished")
  47. if job_state == "JobState.JOB_STATE_FAILED":
  48. return Status("failed")
  49. if job_state == "JobState.JOB_STATE_RUNNING":
  50. return Status("running")
  51. if job_state == "JobState.JOB_STATE_PENDING":
  52. return Status("starting")
  53. return Status("unknown")
  54. async def cancel(self) -> None:
  55. self._job.cancel()
  56. class VertexRunner(AbstractRunner):
  57. """Runner class, uses a project to create a VertexSubmittedRun."""
  58. def __init__(
  59. self,
  60. api: Api,
  61. backend_config: dict[str, Any],
  62. environment: GcpEnvironment,
  63. registry: AbstractRegistry,
  64. ) -> None:
  65. """Initialize a VertexRunner instance."""
  66. super().__init__(api, backend_config)
  67. self.environment = environment
  68. self.registry = registry
  69. async def run(
  70. self, launch_project: LaunchProject, image_uri: str
  71. ) -> AbstractRun | None:
  72. """Run a Vertex job."""
  73. full_resource_args = launch_project.fill_macros(image_uri)
  74. resource_args = full_resource_args.get("vertex")
  75. # We support setting under gcp-vertex for historical reasons.
  76. if not resource_args:
  77. resource_args = full_resource_args.get("gcp-vertex")
  78. if not resource_args:
  79. raise LaunchError(
  80. "No Vertex resource args specified. Specify args via --resource-args with a JSON file or string under top-level key gcp_vertex"
  81. )
  82. spec_args = resource_args.get("spec", {})
  83. run_args = resource_args.get("run", {})
  84. synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
  85. entry_point = (
  86. launch_project.override_entrypoint or launch_project.get_job_entry_point()
  87. )
  88. # TODO: Set entrypoint in each container
  89. entry_cmd = []
  90. if entry_point is not None:
  91. entry_cmd += entry_point.command
  92. entry_cmd += launch_project.override_args
  93. env_vars = launch_project.get_env_vars_dict(
  94. api=self._api,
  95. max_env_length=MAX_ENV_LENGTHS[self.__class__.__name__],
  96. )
  97. worker_specs = spec_args.get("worker_pool_specs", [])
  98. if not worker_specs:
  99. raise LaunchError(
  100. "Vertex requires at least one worker pool spec. Please specify "
  101. "a worker pool spec in resource arguments under the key "
  102. "`vertex.spec.worker_pool_specs`."
  103. )
  104. # TODO: Add entrypoint + args to each worker pool spec
  105. for spec in worker_specs:
  106. if not spec.get("container_spec"):
  107. raise LaunchError(
  108. "Vertex requires a container spec for each worker pool spec. "
  109. "Please specify a container spec in resource arguments under "
  110. "the key `vertex.spec.worker_pool_specs[].container_spec`."
  111. )
  112. spec["container_spec"]["command"] = entry_cmd
  113. # Add our env vars to user supplied env vars
  114. env = spec["container_spec"].get("env", [])
  115. env.extend(
  116. [{"name": key, "value": value} for key, value in env_vars.items()]
  117. )
  118. spec["container_spec"]["env"] = env
  119. if not spec_args.get("staging_bucket"):
  120. raise LaunchError(
  121. "Vertex requires a staging bucket. Please specify a staging bucket "
  122. "in resource arguments under the key `vertex.spec.staging_bucket`."
  123. )
  124. _logger.info("Launching Vertex job...")
  125. submitted_run = await launch_vertex_job(
  126. launch_project,
  127. spec_args,
  128. run_args,
  129. self.environment,
  130. synchronous,
  131. )
  132. return submitted_run
  133. async def launch_vertex_job(
  134. launch_project: LaunchProject,
  135. spec_args: dict[str, Any],
  136. run_args: dict[str, Any],
  137. environment: GcpEnvironment,
  138. synchronous: bool = False,
  139. ) -> VertexSubmittedRun:
  140. try:
  141. await environment.verify()
  142. aiplatform = get_module(
  143. "google.cloud.aiplatform",
  144. "VertexRunner requires google.cloud.aiplatform to be installed",
  145. )
  146. init = event_loop_thread_exec(aiplatform.init)
  147. await init(
  148. project=environment.project,
  149. location=environment.region,
  150. staging_bucket=spec_args.get("staging_bucket"),
  151. credentials=await environment.get_credentials(),
  152. )
  153. labels = spec_args.get("labels", {})
  154. labels[WANDB_RUN_ID_KEY] = launch_project.run_id
  155. job = aiplatform.CustomJob(
  156. display_name=launch_project.name,
  157. worker_pool_specs=spec_args.get("worker_pool_specs"),
  158. base_output_dir=spec_args.get("base_output_dir"),
  159. encryption_spec_key_name=spec_args.get("encryption_spec_key_name"),
  160. labels=labels,
  161. )
  162. execution_kwargs = dict(
  163. timeout=run_args.get("timeout"),
  164. service_account=run_args.get("service_account"),
  165. network=run_args.get("network"),
  166. enable_web_access=run_args.get("enable_web_access", False),
  167. experiment=run_args.get("experiment"),
  168. experiment_run=run_args.get("experiment_run"),
  169. tensorboard=run_args.get("tensorboard"),
  170. restart_job_on_worker_restart=run_args.get(
  171. "restart_job_on_worker_restart", False
  172. ),
  173. )
  174. # Unclear if there are exceptions that can be thrown where we should
  175. # retry instead of erroring. For now, just catch all exceptions and they
  176. # go to the UI for the user to interpret.
  177. except Exception as e:
  178. raise LaunchError(f"Failed to create Vertex job: {e}")
  179. if synchronous:
  180. run = event_loop_thread_exec(job.run)
  181. await run(**execution_kwargs, sync=True)
  182. else:
  183. submit = event_loop_thread_exec(job.submit)
  184. await submit(**execution_kwargs)
  185. submitted_run = VertexSubmittedRun(job)
  186. interval = 1
  187. while not getattr(job._gca_resource, "name", None):
  188. # give time for the gcp job object to be created and named, this should only loop a couple times max
  189. await asyncio.sleep(interval)
  190. interval = min(30, interval * 2)
  191. return submitted_run