from __future__ import annotations import asyncio import logging from typing import Any if False: from google.cloud import aiplatform # type: ignore # noqa: F401 from wandb.apis.internal import Api from wandb.util import get_module from .._project_spec import LaunchProject from ..environment.gcp_environment import GcpEnvironment from ..errors import LaunchError from ..registry.abstract import AbstractRegistry from ..utils import MAX_ENV_LENGTHS, PROJECT_SYNCHRONOUS, event_loop_thread_exec from .abstract import AbstractRun, AbstractRunner, Status GCP_CONSOLE_URI = "https://console.cloud.google.com" _logger = logging.getLogger(__name__) WANDB_RUN_ID_KEY = "wandb-run-id" class VertexSubmittedRun(AbstractRun): def __init__(self, job: Any) -> None: self._job = job @property def id(self) -> str: # numeric ID of the custom training job return self._job.name # type: ignore async def get_logs(self) -> str | None: # TODO: implement return None @property def name(self) -> str: return self._job.display_name # type: ignore @property def gcp_region(self) -> str: return self._job.location # type: ignore @property def gcp_project(self) -> str: return self._job.project # type: ignore def get_page_link(self) -> str: return f"{GCP_CONSOLE_URI}/vertex-ai/locations/{self.gcp_region}/training/{self.id}?project={self.gcp_project}" async def wait(self) -> bool: # TODO: run this in a separate thread. await self._job.wait() return (await self.get_status()).state == "finished" async def get_status(self) -> Status: job_state = str(self._job.state) # extract from type PipelineState if job_state == "JobState.JOB_STATE_SUCCEEDED": return Status("finished") if job_state == "JobState.JOB_STATE_FAILED": return Status("failed") if job_state == "JobState.JOB_STATE_RUNNING": return Status("running") if job_state == "JobState.JOB_STATE_PENDING": return Status("starting") return Status("unknown") async def cancel(self) -> None: self._job.cancel() class VertexRunner(AbstractRunner): """Runner class, uses a project to create a VertexSubmittedRun.""" def __init__( self, api: Api, backend_config: dict[str, Any], environment: GcpEnvironment, registry: AbstractRegistry, ) -> None: """Initialize a VertexRunner instance.""" super().__init__(api, backend_config) self.environment = environment self.registry = registry async def run( self, launch_project: LaunchProject, image_uri: str ) -> AbstractRun | None: """Run a Vertex job.""" full_resource_args = launch_project.fill_macros(image_uri) resource_args = full_resource_args.get("vertex") # We support setting under gcp-vertex for historical reasons. if not resource_args: resource_args = full_resource_args.get("gcp-vertex") if not resource_args: raise LaunchError( "No Vertex resource args specified. Specify args via --resource-args with a JSON file or string under top-level key gcp_vertex" ) spec_args = resource_args.get("spec", {}) run_args = resource_args.get("run", {}) synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS] entry_point = ( launch_project.override_entrypoint or launch_project.get_job_entry_point() ) # TODO: Set entrypoint in each container entry_cmd = [] if entry_point is not None: entry_cmd += entry_point.command entry_cmd += launch_project.override_args env_vars = launch_project.get_env_vars_dict( api=self._api, max_env_length=MAX_ENV_LENGTHS[self.__class__.__name__], ) worker_specs = spec_args.get("worker_pool_specs", []) if not worker_specs: raise LaunchError( "Vertex requires at least one worker pool spec. Please specify " "a worker pool spec in resource arguments under the key " "`vertex.spec.worker_pool_specs`." ) # TODO: Add entrypoint + args to each worker pool spec for spec in worker_specs: if not spec.get("container_spec"): raise LaunchError( "Vertex requires a container spec for each worker pool spec. " "Please specify a container spec in resource arguments under " "the key `vertex.spec.worker_pool_specs[].container_spec`." ) spec["container_spec"]["command"] = entry_cmd # Add our env vars to user supplied env vars env = spec["container_spec"].get("env", []) env.extend( [{"name": key, "value": value} for key, value in env_vars.items()] ) spec["container_spec"]["env"] = env if not spec_args.get("staging_bucket"): raise LaunchError( "Vertex requires a staging bucket. Please specify a staging bucket " "in resource arguments under the key `vertex.spec.staging_bucket`." ) _logger.info("Launching Vertex job...") submitted_run = await launch_vertex_job( launch_project, spec_args, run_args, self.environment, synchronous, ) return submitted_run async def launch_vertex_job( launch_project: LaunchProject, spec_args: dict[str, Any], run_args: dict[str, Any], environment: GcpEnvironment, synchronous: bool = False, ) -> VertexSubmittedRun: try: await environment.verify() aiplatform = get_module( "google.cloud.aiplatform", "VertexRunner requires google.cloud.aiplatform to be installed", ) init = event_loop_thread_exec(aiplatform.init) await init( project=environment.project, location=environment.region, staging_bucket=spec_args.get("staging_bucket"), credentials=await environment.get_credentials(), ) labels = spec_args.get("labels", {}) labels[WANDB_RUN_ID_KEY] = launch_project.run_id job = aiplatform.CustomJob( display_name=launch_project.name, worker_pool_specs=spec_args.get("worker_pool_specs"), base_output_dir=spec_args.get("base_output_dir"), encryption_spec_key_name=spec_args.get("encryption_spec_key_name"), labels=labels, ) execution_kwargs = dict( timeout=run_args.get("timeout"), service_account=run_args.get("service_account"), network=run_args.get("network"), enable_web_access=run_args.get("enable_web_access", False), experiment=run_args.get("experiment"), experiment_run=run_args.get("experiment_run"), tensorboard=run_args.get("tensorboard"), restart_job_on_worker_restart=run_args.get( "restart_job_on_worker_restart", False ), ) # Unclear if there are exceptions that can be thrown where we should # retry instead of erroring. For now, just catch all exceptions and they # go to the UI for the user to interpret. except Exception as e: raise LaunchError(f"Failed to create Vertex job: {e}") if synchronous: run = event_loop_thread_exec(job.run) await run(**execution_kwargs, sync=True) else: submit = event_loop_thread_exec(job.submit) await submit(**execution_kwargs) submitted_run = VertexSubmittedRun(job) interval = 1 while not getattr(job._gca_resource, "name", None): # give time for the gcp job object to be created and named, this should only loop a couple times max await asyncio.sleep(interval) interval = min(30, interval * 2) return submitted_run