"""Implementation of KubernetesRunner class for wandb launch.""" from __future__ import annotations import asyncio import base64 import datetime import json import logging import os import shlex import time from collections.abc import Iterator from typing import Any import yaml import wandb from wandb.apis.internal import Api from wandb.sdk.launch.agent.agent import LaunchAgent from wandb.sdk.launch.environment.abstract import AbstractEnvironment from wandb.sdk.launch.registry.abstract import AbstractRegistry from wandb.sdk.launch.registry.azure_container_registry import AzureContainerRegistry from wandb.sdk.launch.registry.local_registry import LocalRegistry from wandb.sdk.launch.runner.abstract import Status from wandb.sdk.launch.runner.kubernetes_monitor import ( WANDB_K8S_LABEL_AGENT, WANDB_K8S_LABEL_AUXILIARY_RESOURCE, WANDB_K8S_LABEL_MONITOR, WANDB_K8S_RUN_ID, CustomResource, LaunchKubernetesMonitor, ) from wandb.sdk.launch.utils import ( recursive_macro_sub, sanitize_identifiers_for_k8s, yield_containers, ) from wandb.sdk.lib.retry import ExponentialBackoff, retry_async from wandb.util import get_module from .._project_spec import EntryPoint, LaunchProject from ..errors import LaunchError from ..utils import ( CODE_MOUNT_DIR, LOG_PREFIX, MAX_ENV_LENGTHS, PROJECT_SYNCHRONOUS, get_kube_context_and_api_client, make_k8s_label_safe, make_name_dns_safe, ) from .abstract import AbstractRun, AbstractRunner get_module( "kubernetes_asyncio", required="Kubernetes runner requires the kubernetes package. Please install it with `pip install wandb[launch]`.", ) import kubernetes_asyncio # type: ignore # noqa: E402 from kubernetes_asyncio import client # noqa: E402 from kubernetes_asyncio.client.api.apps_v1_api import ( # type: ignore # noqa: E402 AppsV1Api, ) from kubernetes_asyncio.client.api.batch_v1_api import ( # type: ignore # noqa: E402 BatchV1Api, ) from kubernetes_asyncio.client.api.core_v1_api import ( # type: ignore # noqa: E402 CoreV1Api, ) from kubernetes_asyncio.client.api.custom_objects_api import ( # type: ignore # noqa: E402 CustomObjectsApi, ) from kubernetes_asyncio.client.api.networking_v1_api import ( # type: ignore # noqa: E402 NetworkingV1Api, ) from kubernetes_asyncio.client.models.v1_secret import ( # type: ignore # noqa: E402 V1Secret, ) from kubernetes_asyncio.client.rest import ApiException # type: ignore # noqa: E402 TIMEOUT = 5 API_KEY_SECRET_MAX_RETRIES = 5 _logger = logging.getLogger(__name__) SOURCE_CODE_PVC_MOUNT_PATH = os.environ.get("WANDB_LAUNCH_CODE_PVC_MOUNT_PATH") SOURCE_CODE_PVC_NAME = os.environ.get("WANDB_LAUNCH_CODE_PVC_NAME") class KubernetesSubmittedRun(AbstractRun): """Wrapper for a launched run on Kubernetes.""" def __init__( self, batch_api: BatchV1Api, core_api: CoreV1Api, apps_api: AppsV1Api, network_api: NetworkingV1Api, name: str, namespace: str | None = "default", secret: V1Secret | None = None, auxiliary_resource_label_key: str | None = None, ) -> None: """Initialize a KubernetesSubmittedRun. Other implementations of the AbstractRun interface poll on the run when `get_status` is called, but KubernetesSubmittedRun uses Kubernetes watch streams to update the run status. One thread handles events from the job object and another thread handles events from the rank 0 pod. These threads updated the `_status` attributed of the KubernetesSubmittedRun object. When `get_status` is called, the `_status` attribute is returned. Args: batch_api: Kubernetes BatchV1Api object. core_api: Kubernetes CoreV1Api object. network_api: Kubernetes NetworkV1Api object. name: Name of the job. namespace: Kubernetes namespace. secret: Kubernetes secret. Returns: None. """ self.batch_api = batch_api self.core_api = core_api self.apps_api = apps_api self.network_api = network_api self.name = name self.namespace = namespace self._fail_count = 0 self.secret = secret self.auxiliary_resource_label_key = auxiliary_resource_label_key @property def id(self) -> str: """Return the run id.""" return self.name async def get_logs(self) -> str | None: try: pods = await self.core_api.list_namespaced_pod( label_selector=f"job-name={self.name}", namespace=self.namespace ) pod_names = [pi.metadata.name for pi in pods.items] if not pod_names: wandb.termwarn(f"Found no pods for kubernetes job: {self.name}") return None logs = await self.core_api.read_namespaced_pod_log( name=pod_names[0], namespace=self.namespace ) if logs: return str(logs) else: wandb.termwarn(f"No logs for kubernetes pod(s): {pod_names}") return None except Exception as e: wandb.termerror(f"{LOG_PREFIX}Failed to get pod logs: {e}") return None async def wait(self) -> bool: """Wait for the run to finish. Returns: True if the run finished successfully, False otherwise. """ while True: status = await self.get_status() wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status.state}") if status.state in ["finished", "failed", "preempted"]: break await asyncio.sleep(5) await self._delete_secret() return ( status.state == "finished" ) # todo: not sure if this (copied from aws runner) is the right approach? should we return false on failure async def get_status(self) -> Status: status = LaunchKubernetesMonitor.get_status(self.name) if status in ["stopped", "failed", "finished", "preempted"]: await self._delete_secret() return status async def cancel(self) -> None: """Cancel the run.""" try: await self.batch_api.delete_namespaced_job( namespace=self.namespace, name=self.name, ) await self._delete_secret() except ApiException as e: raise LaunchError( f"Failed to delete Kubernetes Job {self.name} in namespace {self.namespace}: {str(e)}" ) from e async def _delete_secret(self) -> None: # Cleanup secret if not running in a helm-managed context if not os.environ.get("WANDB_RELEASE_NAME") and self.secret: await self.core_api.delete_namespaced_secret( name=self.secret.metadata.name, namespace=self.secret.metadata.namespace, ) self.secret = None async def _delete_auxiliary_resources_by_label(self) -> None: if self.auxiliary_resource_label_key is None: return label_selector = ( f"{WANDB_K8S_LABEL_AUXILIARY_RESOURCE}={self.auxiliary_resource_label_key}" ) try: resource_cleanups = [ (self.core_api, "service"), (self.batch_api, "job"), (self.core_api, "pod"), (self.core_api, "secret"), (self.apps_api, "deployment"), (self.network_api, "network_policy"), ] for api_client, resource_type in resource_cleanups: try: list_method = getattr( api_client, f"list_namespaced_{resource_type}" ) delete_method = getattr( api_client, f"delete_namespaced_{resource_type}" ) # List resources with our label resources = await list_method( namespace=self.namespace, label_selector=label_selector ) # Delete each resource for resource in resources.items: await delete_method( name=resource.metadata.name, namespace=self.namespace ) except (AttributeError, ApiException) as e: wandb.termwarn(f"Could not clean up {resource_type}: {e}") except Exception as e: wandb.termwarn(f"Failed to clean up some auxiliary resources: {e}") class CrdSubmittedRun(AbstractRun): """Run submitted to a CRD backend, e.g. Volcano.""" def __init__( self, group: str, version: str, plural: str, name: str, namespace: str, core_api: CoreV1Api, custom_api: CustomObjectsApi, ) -> None: """Create a run object for tracking the progress of a CRD. Args: group: The API group of the CRD. version: The API version of the CRD. plural: The plural name of the CRD. name: The name of the CRD instance. namespace: The namespace of the CRD instance. core_api: The Kubernetes core API client. custom_api: The Kubernetes custom object API client. Raises: LaunchError: If the CRD instance does not exist. """ self.group = group self.version = version self.plural = plural self.name = name self.namespace = namespace self.core_api = core_api self.custom_api = custom_api self._fail_count = 0 @property def id(self) -> str: """Get the name of the custom object.""" return self.name async def get_logs(self) -> str | None: """Get logs for custom object.""" # TODO: test more carefully once we release multi-node support logs: dict[str, str | None] = {} try: pods = await self.core_api.list_namespaced_pod( label_selector=f"wandb/run-id={self.name}", namespace=self.namespace ) pod_names = [pi.metadata.name for pi in pods.items] for pod_name in pod_names: logs[pod_name] = await self.core_api.read_namespaced_pod_log( name=pod_name, namespace=self.namespace ) except ApiException as e: wandb.termwarn(f"Failed to get logs for {self.name}: {str(e)}") return None if not logs: return None logs_as_array = [f"Pod {pod_name}:\n{log}" for pod_name, log in logs.items()] return "\n".join(logs_as_array) async def get_status(self) -> Status: """Get status of custom object.""" return LaunchKubernetesMonitor.get_status(self.name) async def cancel(self) -> None: """Cancel the custom object.""" try: await self.custom_api.delete_namespaced_custom_object( group=self.group, version=self.version, namespace=self.namespace, plural=self.plural, name=self.name, ) except ApiException as e: raise LaunchError( f"Failed to delete CRD {self.name} in namespace {self.namespace}: {str(e)}" ) from e async def wait(self) -> bool: """Wait for this custom object to finish running.""" while True: status = await self.get_status() wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}") if status.state in ["finished", "failed", "preempted"]: return status.state == "finished" await asyncio.sleep(5) class KubernetesRunner(AbstractRunner): """Launches runs onto kubernetes.""" def __init__( self, api: Api, backend_config: dict[str, Any], environment: AbstractEnvironment, registry: AbstractRegistry, ) -> None: """Create a Kubernetes runner. Args: api: The API client object. backend_config: The backend configuration. environment: The environment to launch runs into. Raises: LaunchError: If the Kubernetes configuration is invalid. """ super().__init__(api, backend_config) self.environment = environment self.registry = registry def get_namespace( self, resource_args: dict[str, Any], context: dict[str, Any] ) -> str: """Get the namespace to launch into. Args: resource_args: The resource args to launch. context: The k8s config context. Returns: The namespace to launch into. """ default_namespace = ( context["context"].get("namespace", "default") if context else "default" ) return ( # type: ignore[no-any-return] resource_args.get("metadata", {}).get("namespace") or resource_args.get( "namespace" ) # continue support for malformed namespace or self.backend_config.get("runner", {}).get("namespace") or default_namespace ) async def _inject_defaults( self, resource_args: dict[str, Any], launch_project: LaunchProject, image_uri: str, namespace: str, core_api: CoreV1Api, ) -> tuple[dict[str, Any], V1Secret | None]: """Apply our default values, return job dict and api key secret. Args: resource_args (Dict[str, Any]): The resource args to launch. launch_project (LaunchProject): The launch project. builder (Optional[AbstractBuilder]): The builder. namespace (str): The namespace. core_api (CoreV1Api): The core api. Returns: Tuple[Dict[str, Any], Optional["V1Secret"]]: The resource args and api key secret. """ job: dict[str, Any] = { "apiVersion": "batch/v1", "kind": "Job", } job.update(resource_args) job_metadata: dict[str, Any] = job.get("metadata", {}) job_spec: dict[str, Any] = {"backoffLimit": 0, "ttlSecondsAfterFinished": 60} job_spec.update(job.get("spec", {})) pod_template: dict[str, Any] = job_spec.get("template", {}) pod_spec: dict[str, Any] = {"restartPolicy": "Never"} pod_spec.update(pod_template.get("spec", {})) containers: list[dict[str, Any]] = pod_spec.get("containers", [{}]) # Add labels to job metadata job_metadata.setdefault("labels", {}) job_metadata["labels"][WANDB_K8S_RUN_ID] = launch_project.run_id job_metadata["labels"][WANDB_K8S_LABEL_MONITOR] = "true" if LaunchAgent.initialized(): job_metadata["labels"][WANDB_K8S_LABEL_AGENT] = LaunchAgent.name() # name precedence: name in spec > generated name if not job_metadata.get("name"): job_metadata["generateName"] = make_name_dns_safe( f"launch-{launch_project.target_entity}-{launch_project.target_project}-" ) job_metadata["namespace"] = namespace for i, cont in enumerate(containers): if "name" not in cont: cont["name"] = cont.get("name", "launch" + str(i)) if "securityContext" not in cont: cont["securityContext"] = { "allowPrivilegeEscalation": False, "capabilities": {"drop": ["ALL"]}, "seccompProfile": {"type": "RuntimeDefault"}, } entry_point = ( launch_project.override_entrypoint or launch_project.get_job_entry_point() ) if launch_project.docker_image: # dont specify run id if user provided image, could have multiple runs containers[0]["image"] = image_uri # TODO: handle secret pulling image from registry elif not any(["image" in cont for cont in containers]): assert entry_point is not None # in the non instance case we need to make an imagePullSecret # so the new job can pull the image containers[0]["image"] = image_uri secret = await maybe_create_imagepull_secret( core_api, self.registry, launch_project.run_id, namespace ) if secret is not None: pod_spec["imagePullSecrets"] = [ {"name": f"regcred-{launch_project.run_id}"} ] inject_entrypoint_and_args( containers, entry_point, launch_project.override_args, launch_project.override_entrypoint is not None, ) env_vars = launch_project.get_env_vars_dict( self._api, MAX_ENV_LENGTHS[self.__class__.__name__] ) api_key_secret = None for cont in containers: # Add our env vars to user supplied env vars env = cont.get("env") or [] for key, value in env_vars.items(): if ( key == "WANDB_API_KEY" and value and ( LaunchAgent.initialized() or self.backend_config[PROJECT_SYNCHRONOUS] ) ): # Override API key with secret. TODO: Do the same for other runners release_name = os.environ.get("WANDB_RELEASE_NAME") secret_name = "wandb-api-key" if release_name: secret_name += f"-{release_name}" else: secret_name += f"-{launch_project.run_id}" def handle_exception(e): wandb.termwarn( f"Exception when ensuring Kubernetes API key secret: {e}. Retrying..." ) api_key_secret = await retry_async( backoff=ExponentialBackoff( initial_sleep=datetime.timedelta(seconds=1), max_sleep=datetime.timedelta(minutes=1), max_retries=API_KEY_SECRET_MAX_RETRIES, ), fn=ensure_api_key_secret, on_exc=handle_exception, core_api=core_api, secret_name=secret_name, namespace=namespace, api_key=value, ) env.append( { "name": key, "valueFrom": { "secretKeyRef": { "name": secret_name, "key": "password", } }, } ) else: env.append({"name": key, "value": value}) cont["env"] = env pod_spec["containers"] = containers pod_template["spec"] = pod_spec job_spec["template"] = pod_template job["spec"] = job_spec job["metadata"] = job_metadata add_label_to_pods( job, WANDB_K8S_LABEL_MONITOR, "true", ) if launch_project.job_base_image: if SOURCE_CODE_PVC_NAME and SOURCE_CODE_PVC_MOUNT_PATH: apply_code_mount_configuration( job, launch_project, ) else: apply_code_mount_configuration_emptydir( job, launch_project, self._api, ) # Add wandb.ai/agent: current agent label on all pods if LaunchAgent.initialized(): add_label_to_pods( job, WANDB_K8S_LABEL_AGENT, LaunchAgent.name(), ) return job, api_key_secret async def _wait_for_resource_ready( self, api_client: kubernetes_asyncio.client.ApiClient, config: dict[str, Any], namespace: str, timeout_seconds: int = 300, ) -> None: """Wait for a Kubernetes resource to be ready. Args: api_client: The Kubernetes API client. config: The resource configuration. namespace: The namespace where the resource was created. timeout_seconds: Maximum time to wait for readiness. """ resource_kind = config.get("kind") resource_name = config.get("metadata", {}).get("name") if not resource_kind or not resource_name: wandb.termerror( f"{LOG_PREFIX}Cannot wait for resource without kind or name" ) return wandb.termlog( f"{LOG_PREFIX}Waiting for {resource_kind} '{resource_name}' to be ready..." ) start_time = time.time() if resource_kind == "Deployment": await self._wait_for_deployment_ready( api_client, resource_name, namespace, timeout_seconds ) elif resource_kind == "Service": await self._wait_for_service_ready( api_client, resource_name, namespace, timeout_seconds ) elif resource_kind == "Pod": await self._wait_for_pod_ready( api_client, resource_name, namespace, timeout_seconds ) else: wandb.termlog( f"{LOG_PREFIX}No specific readiness check for {resource_kind}, waiting 5 seconds..." ) await asyncio.sleep(5) elapsed = time.time() - start_time wandb.termlog( f"{LOG_PREFIX}{resource_kind} '{resource_name}' is ready after {elapsed:.1f}s" ) async def _wait_for_deployment_ready( self, api_client: kubernetes_asyncio.client.ApiClient, name: str, namespace: str, timeout_seconds: int, ) -> None: """Wait for a Deployment to be ready.""" apps_api = kubernetes_asyncio.client.AppsV1Api(api_client) async def check_deployment_ready(): deployment = await apps_api.read_namespaced_deployment( name=name, namespace=namespace ) status = deployment.status if status.ready_replicas and status.replicas: return status.ready_replicas >= status.replicas return False await self._wait_with_timeout(check_deployment_ready, timeout_seconds, name) async def _wait_for_service_ready( self, api_client: kubernetes_asyncio.client.ApiClient, name: str, namespace: str, timeout_seconds: int, ) -> None: """Wait for a Service to have endpoints.""" core_api = kubernetes_asyncio.client.CoreV1Api(api_client) async def check_service_ready(): endpoints = await core_api.read_namespaced_endpoints( name=name, namespace=namespace ) if endpoints.subsets: for subset in endpoints.subsets: if subset.addresses: # These are ready pod addresses return True return False await self._wait_with_timeout(check_service_ready, timeout_seconds, name) async def _wait_for_pod_ready( self, api_client: kubernetes_asyncio.client.ApiClient, name: str, namespace: str, timeout_seconds: int, ) -> None: """Wait for a Pod to be ready.""" core_api = kubernetes_asyncio.client.CoreV1Api(api_client) async def check_pod_ready(): pod = await core_api.read_namespaced_pod(name=name, namespace=namespace) if pod.status.phase == "Running": if pod.status.container_statuses: return all(status.ready for status in pod.status.container_statuses) return True return False await self._wait_with_timeout(check_pod_ready, timeout_seconds, name) async def _wait_with_timeout( self, check_func, timeout_seconds: int, name: str ) -> None: """Generic timeout wrapper for readiness checks.""" start_time = time.time() while time.time() - start_time < timeout_seconds: try: if await check_func(): return except kubernetes_asyncio.client.ApiException as e: if e.status == 404: pass else: wandb.termerror( f"{LOG_PREFIX}Error waiting for resource '{name}': {e}" ) raise except Exception as e: wandb.termerror(f"{LOG_PREFIX}Error waiting for resource '{name}': {e}") raise await asyncio.sleep(2) raise LaunchError( f"Resource '{name}' not ready within {timeout_seconds} seconds" ) async def _prepare_resource( self, api_client: kubernetes_asyncio.client.ApiClient, config: dict[str, Any], namespace: str, run_id: str, launch_project: LaunchProject, api_key_secret: V1Secret | None = None, wait_for_ready: bool = True, wait_timeout: int = 300, auxiliary_resource_label_value: str | None = None, ) -> None: """Prepare a service for launch. Args: api_client: The Kubernetes API client. config: The resource configuration to prepare. namespace: The namespace to create the resource in. run_id: The run ID to label the resource with. launch_project: The launch project to get environment variables from. api_key_secret: The API key secret to inject. wait_for_ready: Whether to wait for the resource to be ready after creation. wait_timeout: Maximum time in seconds to wait for resource readiness. """ config.setdefault("metadata", {}) config["metadata"].setdefault("labels", {}) config["metadata"]["labels"][WANDB_K8S_RUN_ID] = run_id config["metadata"]["labels"]["wandb.ai/created-by"] = "launch-agent" if auxiliary_resource_label_value: config["metadata"]["labels"][WANDB_K8S_LABEL_AUXILIARY_RESOURCE] = ( auxiliary_resource_label_value ) env_vars = launch_project.get_env_vars_dict( self._api, MAX_ENV_LENGTHS[self.__class__.__name__] ) wandb_config_env = { "WANDB_CONFIG": env_vars.get("WANDB_CONFIG", "{}"), } add_wandb_env(config, wandb_config_env) if auxiliary_resource_label_value: add_label_to_pods( config, WANDB_K8S_LABEL_AUXILIARY_RESOURCE, auxiliary_resource_label_value, ) if api_key_secret: for cont in yield_containers(config): env = cont.setdefault("env", []) env.append( { "name": "WANDB_API_KEY", "valueFrom": { "secretKeyRef": { "name": api_key_secret.metadata.name, "key": "password", } }, } ) cont["env"] = env try: sanitize_identifiers_for_k8s(config) await kubernetes_asyncio.utils.create_from_dict( api_client, config, namespace=namespace ) if wait_for_ready: await self._wait_for_resource_ready( api_client, config, namespace, wait_timeout ) except Exception as e: wandb.termerror(f"{LOG_PREFIX}Failed to create Kubernetes resource: {e}") raise LaunchError(f"Failed to create Kubernetes resource: {e}") async def run( self, launch_project: LaunchProject, image_uri: str ) -> AbstractRun | None: """Execute a launch project on Kubernetes. Args: launch_project: The launch project to execute. builder: The builder to use to build the image. Returns: The run object if the run was successful, otherwise None. """ await LaunchKubernetesMonitor.ensure_initialized() resource_args = launch_project.fill_macros(image_uri).get("kubernetes", {}) if not resource_args: wandb.termlog( f"{LOG_PREFIX}Note: no resource args specified. Add a " "Kubernetes yaml spec or other options in a json file " "with --resource-args ." ) _logger.info(f"Running Kubernetes job with resource args: {resource_args}") context, api_client = await get_kube_context_and_api_client( kubernetes_asyncio, resource_args ) # If using pvc for code mount, move code there. use_emptydir_code_mount = False if launch_project.job_base_image is not None: if SOURCE_CODE_PVC_NAME and SOURCE_CODE_PVC_MOUNT_PATH: code_subdir = launch_project.get_image_source_string() launch_project.change_project_dir( os.path.join(SOURCE_CODE_PVC_MOUNT_PATH, code_subdir) ) else: use_emptydir_code_mount = True # If the user specified an alternate api, we need will execute this # run by creating a custom object. api_version = resource_args.get("apiVersion", "batch/v1") if api_version not in ["batch/v1", "batch/v1beta1"]: env_vars = launch_project.get_env_vars_dict( self._api, MAX_ENV_LENGTHS[self.__class__.__name__] ) # Crawl the resource args and add our env vars to the containers. add_wandb_env(resource_args, env_vars) # Add our labels to the resource args. This is necessary for the # agent to find the custom object later on. resource_args["metadata"] = resource_args.get("metadata", {}) resource_args["metadata"]["labels"] = resource_args["metadata"].get( "labels", {} ) resource_args["metadata"]["labels"][WANDB_K8S_LABEL_MONITOR] = "true" # Crawl the resource arsg and add our labels to the pods. This is # necessary for the agent to find the pods later on. add_label_to_pods( resource_args, WANDB_K8S_LABEL_MONITOR, "true", ) # Add wandb.ai/agent: current agent label on all pods if LaunchAgent.initialized(): add_label_to_pods( resource_args, WANDB_K8S_LABEL_AGENT, LaunchAgent.name(), ) resource_args["metadata"]["labels"][WANDB_K8S_LABEL_AGENT] = ( LaunchAgent.name() ) if launch_project.job_base_image: if use_emptydir_code_mount: apply_code_mount_configuration_emptydir( resource_args, launch_project, self._api ) else: apply_code_mount_configuration(resource_args, launch_project) overrides = {} if launch_project.override_args: overrides["args"] = launch_project.override_args if launch_project.override_entrypoint: overrides["command"] = launch_project.override_entrypoint.command add_entrypoint_args_overrides( resource_args, overrides, ) api = client.CustomObjectsApi(api_client) # Infer the attributes of a custom object from the apiVersion and/or # a kind: attribute in the resource args. namespace = self.get_namespace(resource_args, context) group, version, *_ = api_version.split("/") group = resource_args.get("group", group) version = resource_args.get("version", version) kind = resource_args.get("kind", version) plural = f"{kind.lower()}s" custom_resource = CustomResource( group=group, version=version, plural=plural, ) LaunchKubernetesMonitor.monitor_namespace( namespace, custom_resource=custom_resource ) try: response = await api.create_namespaced_custom_object( group=group, version=version, namespace=namespace, plural=plural, body=resource_args, ) except ApiException as e: body = json.loads(e.body) body_yaml = yaml.dump(body) raise LaunchError( f"Error creating CRD of kind {kind}: {e.status} {e.reason}\n{body_yaml}" ) from e name = response.get("metadata", {}).get("name") _logger.info(f"Created {kind} {response['metadata']['name']}") submitted_run = CrdSubmittedRun( name=name, group=group, version=version, namespace=namespace, plural=plural, core_api=client.CoreV1Api(api_client), custom_api=api, ) if self.backend_config[PROJECT_SYNCHRONOUS]: await submitted_run.wait() return submitted_run batch_api = kubernetes_asyncio.client.BatchV1Api(api_client) core_api = kubernetes_asyncio.client.CoreV1Api(api_client) apps_api = kubernetes_asyncio.client.AppsV1Api(api_client) network_api = kubernetes_asyncio.client.NetworkingV1Api(api_client) namespace = self.get_namespace(resource_args, context) job, secret = await self._inject_defaults( resource_args, launch_project, image_uri, namespace, core_api ) update_dict = { "project_name": launch_project.target_project, "entity_name": launch_project.target_entity, "run_id": launch_project.run_id, "run_name": launch_project.name, "image_uri": image_uri, "author": launch_project.author, } update_dict.update(os.environ) additional_services: list[dict[str, Any]] = recursive_macro_sub( launch_project.launch_spec.get("additional_services", []), update_dict ) auxiliary_resource_label_value = make_k8s_label_safe( f"aux-{launch_project.target_entity}-{launch_project.target_project}-{launch_project.run_id}" ) if additional_services: wandb.termlog( f"{LOG_PREFIX}Creating additional services: {additional_services}" ) wait_for_ready = resource_args.get("wait_for_ready", True) wait_timeout = resource_args.get("wait_timeout", 300) await asyncio.gather( *[ self._prepare_resource( api_client, resource.get("config", {}), namespace, launch_project.run_id, launch_project, secret, wait_for_ready, wait_timeout, auxiliary_resource_label_value, ) for resource in additional_services if resource.get("config", {}) ] ) msg = "Creating Kubernetes job" if "name" in resource_args: msg += f": {resource_args['name']}" _logger.info(msg) try: response = await kubernetes_asyncio.utils.create_from_dict( api_client, job, namespace=namespace ) except kubernetes_asyncio.utils.FailToCreateError as e: for exc in e.api_exceptions: resp = json.loads(exc.body) msg = resp.get("message") code = resp.get("code") raise LaunchError( f"Failed to create Kubernetes job for run {launch_project.run_id} ({code} {exc.reason}): {msg}" ) except Exception as e: raise LaunchError( f"Unexpected exception when creating Kubernetes job: {str(e)}\n" ) job_response = response[0] job_name = job_response.metadata.name LaunchKubernetesMonitor.monitor_namespace(namespace) submitted_job = KubernetesSubmittedRun( batch_api, core_api, apps_api, network_api, job_name, namespace, secret, auxiliary_resource_label_value, ) if self.backend_config[PROJECT_SYNCHRONOUS]: await submitted_job.wait() return submitted_job def inject_entrypoint_and_args( containers: list[dict], entry_point: EntryPoint | None, override_args: list[str], should_override_entrypoint: bool, ) -> None: """Inject the entrypoint and args into the containers. Args: containers: The containers to inject the entrypoint and args into. entry_point: The entrypoint to inject. override_args: The args to inject. should_override_entrypoint: Whether to override the entrypoint. Returns: None """ for i in range(len(containers)): if override_args: containers[i]["args"] = override_args if entry_point and ( not containers[i].get("command") or should_override_entrypoint ): containers[i]["command"] = entry_point.command async def ensure_api_key_secret( core_api: CoreV1Api, secret_name: str, namespace: str, api_key: str, ) -> V1Secret: """Create a secret containing a user's wandb API key. Args: core_api: The Kubernetes CoreV1Api object. secret_name: The name to use for the secret. namespace: The namespace to create the secret in. api_key: The user's wandb API key Returns: The created secret """ secret_data = {"password": base64.b64encode(api_key.encode()).decode()} labels = {"wandb.ai/created-by": "launch-agent"} secret = client.V1Secret( data=secret_data, metadata=client.V1ObjectMeta( name=secret_name, namespace=namespace, labels=labels ), kind="Secret", type="kubernetes.io/basic-auth", ) try: try: return await core_api.create_namespaced_secret(namespace, secret) except ApiException as e: # 409 = conflict = secret already exists if e.status == 409: existing_secret = await core_api.read_namespaced_secret( name=secret_name, namespace=namespace ) if existing_secret.data != secret_data: # If it's a previous secret made by launch agent, clean it up if ( existing_secret.metadata.labels.get("wandb.ai/created-by") == "launch-agent" ): await core_api.delete_namespaced_secret( name=secret_name, namespace=namespace ) return await core_api.create_namespaced_secret( namespace, secret ) else: raise LaunchError( f"Kubernetes secret already exists in namespace {namespace} with incorrect data: {secret_name}" ) return existing_secret raise except Exception as e: raise LaunchError( f"Exception when ensuring Kubernetes API key secret: {str(e)}\n" ) async def maybe_create_imagepull_secret( core_api: CoreV1Api, registry: AbstractRegistry, run_id: str, namespace: str, ) -> V1Secret | None: """Create a secret for pulling images from a private registry. Args: core_api: The Kubernetes CoreV1Api object. registry: The registry to pull from. run_id: The run id. namespace: The namespace to create the secret in. Returns: A secret if one was created, otherwise None. """ secret = None if isinstance(registry, (LocalRegistry, AzureContainerRegistry)): # Secret not required return None uname, token = await registry.get_username_password() creds_info = { "auths": { registry.uri: { "auth": base64.b64encode(f"{uname}:{token}".encode()).decode(), # need an email but the use is deprecated "email": "deprecated@wandblaunch.com", } } } secret_data = { ".dockerconfigjson": base64.b64encode(json.dumps(creds_info).encode()).decode() } secret = client.V1Secret( data=secret_data, metadata=client.V1ObjectMeta(name=f"regcred-{run_id}", namespace=namespace), kind="Secret", type="kubernetes.io/dockerconfigjson", ) try: try: return await core_api.create_namespaced_secret(namespace, secret) except ApiException as e: # 409 = conflict = secret already exists if e.status == 409: return await core_api.read_namespaced_secret( name=f"regcred-{run_id}", namespace=namespace ) raise except Exception as e: raise LaunchError(f"Exception when creating Kubernetes secret: {str(e)}\n") def add_wandb_env(root: dict | list, env_vars: dict[str, str]) -> None: """Injects wandb environment variables into specs. Recursively walks the spec and injects the environment variables into every container spec. Containers are identified by the "containers" key. This function treats the WANDB_RUN_ID and WANDB_GROUP_ID environment variables specially. If they are present in the spec, they will be overwritten. If a setting for WANDB_RUN_ID is provided in env_vars, then that environment variable will only be set in the first container modified by this function. Args: root: The spec to modify. env_vars: The environment variables to inject. Returns: None. """ for cont in yield_containers(root): env = cont.setdefault("env", []) env.extend([{"name": key, "value": value} for key, value in env_vars.items()]) cont["env"] = env # After we have set WANDB_RUN_ID once, we don't want to set it again if "WANDB_RUN_ID" in env_vars: env_vars.pop("WANDB_RUN_ID") def yield_pods(manifest: Any) -> Iterator[dict]: """Yield all pod specs in a manifest. Recursively traverses the manifest and yields all pod specs. Pod specs are identified by the presence of a "spec" key with a "containers" key in the value. """ if isinstance(manifest, list): for item in manifest: yield from yield_pods(item) elif isinstance(manifest, dict): if "spec" in manifest and "containers" in manifest["spec"]: yield manifest for value in manifest.values(): if isinstance(value, (dict, list)): yield from yield_pods(value) def add_label_to_pods(manifest: dict | list, label_key: str, label_value: str) -> None: """Add a label to all pod specs in a manifest. Recursively traverses the manifest and adds the label to all pod specs. Pod specs are identified by the presence of a "spec" key with a "containers" key in the value. Args: manifest: The manifest to modify. label_key: The label key to add. label_value: The label value to add. Returns: None. """ for pod in yield_pods(manifest): metadata = pod.setdefault("metadata", {}) labels = metadata.setdefault("labels", {}) labels[label_key] = label_value def add_entrypoint_args_overrides(manifest: dict | list, overrides: dict) -> None: """Add entrypoint and args overrides to all containers in a manifest. Recursively traverses the manifest and adds the entrypoint and args overrides to all containers. Containers are identified by the presence of a "spec" key with a "containers" key in the value. Args: manifest: The manifest to modify. overrides: Dictionary with args and entrypoint keys. Returns: None. """ if isinstance(manifest, list): for item in manifest: add_entrypoint_args_overrides(item, overrides) elif isinstance(manifest, dict): if "spec" in manifest and "containers" in manifest["spec"]: containers = manifest["spec"]["containers"] for container in containers: if "command" in overrides: container["command"] = overrides["command"] if "args" in overrides: container["args"] = overrides["args"] for value in manifest.values(): add_entrypoint_args_overrides(value, overrides) def _set_container_command_with_dep_install( container: dict, working_dir: str, requirements_path: str, ) -> None: """Set a container's command to install dependencies then exec the original command. Replaces command/args with a shell one-liner that installs dependencies, checking in order: 1. requirements.txt (user-provided) 2. pyproject.toml (user-provided, installed via pip install .) 3. requirements.frozen.txt (job artifact fallback) Args: container: The container spec to modify in place. working_dir: The working directory where user dep files are expected. requirements_path: Path to the frozen requirements fallback file. """ original_command = container.get("command", []) original_args = container.get("args", []) original_cmd_str = " ".join( shlex.quote(c) for c in original_command + original_args ) if not original_cmd_str: return user_requirements = f"{working_dir}/requirements.txt" pyproject = f"{working_dir}/pyproject.toml" install_prefix = ( f"if [ -f {shlex.quote(user_requirements)} ]; then" f" pip install -r {shlex.quote(user_requirements)};" f" elif [ -f {shlex.quote(pyproject)} ]; then" f" pip install {shlex.quote(working_dir)};" f" elif [ -f {shlex.quote(requirements_path)} ]; then" f" pip install -r {shlex.quote(requirements_path)};" f" else echo 'No requirements file found'; fi" ) container["command"] = ["/bin/sh", "-c"] container["args"] = [f"{install_prefix} && exec {original_cmd_str}"] def apply_code_mount_configuration( manifest: dict | list, project: LaunchProject ) -> None: """Apply code mount configuration to all containers in a manifest. Recursively traverses the manifest and adds the code mount configuration to all containers. Containers are identified by the presence of a "spec" key with a "containers" key in the value. Args: manifest: The manifest to modify. project: The launch project. Returns: None. """ assert SOURCE_CODE_PVC_NAME is not None source_dir = project.get_image_source_string() for pod in yield_pods(manifest): for container in yield_containers(pod): if "volumeMounts" not in container: container["volumeMounts"] = [] container["volumeMounts"].append( { "name": "wandb-source-code-volume", "mountPath": CODE_MOUNT_DIR, "subPath": source_dir, } ) container["workingDir"] = project.resolved_working_dir if project._auto_default_base_image: _set_container_command_with_dep_install( container, project.resolved_working_dir, f"{CODE_MOUNT_DIR}/.job/requirements.frozen.txt", ) spec = pod["spec"] if "volumes" not in spec: spec["volumes"] = [] spec["volumes"].append( { "name": "wandb-source-code-volume", "persistentVolumeClaim": { "claimName": SOURCE_CODE_PVC_NAME, }, } ) def _build_code_fetch_script( source_type: str, source_info: dict, install_deps: bool, job_dir: str, ) -> str: """Build the shell script for the init container to fetch source code. Args: source_type: Either "artifact" or "repo". source_info: Source metadata from the launch project. install_deps: Whether to also fetch the job artifact for frozen requirements. job_dir: Path where the job artifact should be downloaded. """ job_artifact = source_info.get("job_artifact", "") chmod_suffix = f" && chmod -R a+w {CODE_MOUNT_DIR}/* || true && chmod -R a+w {CODE_MOUNT_DIR}/.* || true" fetch_job_artifact = "" if install_deps and job_artifact: py_cmd = f"import wandb; wandb.Api().artifact({repr(job_artifact)}).download({repr(job_dir)})" fetch_job_artifact = f" && python -c {shlex.quote(py_cmd)}" if source_type == "artifact": artifact_string = source_info.get("artifact_string", "") py_cmd = f"import wandb; wandb.Api().artifact({repr(artifact_string)}).download({repr(CODE_MOUNT_DIR)})" return f"python -c {shlex.quote(py_cmd)}" + fetch_job_artifact + chmod_suffix else: # repo git_remote = source_info.get("git_remote", "") git_commit = source_info.get("git_commit", "") return ( f"git clone {shlex.quote(git_remote)} {CODE_MOUNT_DIR}" f" && git config --global --add safe.directory {CODE_MOUNT_DIR}" f" && cd {CODE_MOUNT_DIR} && git checkout {shlex.quote(git_commit)}" + fetch_job_artifact + chmod_suffix ) def _build_source_init_container( fetch_script: str, base_url: str, api_key_env: dict | None, ) -> dict: """Build the init container spec that fetches source code into the emptyDir volume. Args: fetch_script: Shell script to run in the init container. base_url: W&B base URL passed as an environment variable. api_key_env: Optional WANDB_API_KEY env dict extracted from a main container. """ init_env: list[dict] = [{"name": "WANDB_BASE_URL", "value": base_url}] if api_key_env: init_env.append(api_key_env) return { "name": "wandb-source-code-init", "image": "wandb/launch-agent:latest", "volumeMounts": [ {"name": "wandb-source-code-volume", "mountPath": CODE_MOUNT_DIR} ], "env": init_env, "command": ["/bin/sh", "-c", fetch_script], } def _configure_containers_for_code_mount( pod: dict, project: LaunchProject, install_deps: bool, job_dir: str, ) -> dict | None: """Mount the code volume on all main containers and return the first WANDB_API_KEY env entry. Args: pod: The pod spec dict to modify in place. project: The launch project (for workingDir and dep-install config). install_deps: Whether to wrap container commands with a dep-install step. job_dir: Path to frozen requirements inside the mounted volume. """ api_key_env = None for container in yield_containers(pod): container.setdefault("volumeMounts", []).append( {"name": "wandb-source-code-volume", "mountPath": CODE_MOUNT_DIR} ) container["workingDir"] = project.resolved_working_dir # Only install deps when using the auto-assigned default base image. # User-provided base images are expected to already have deps. if install_deps: _set_container_command_with_dep_install( container, project.resolved_working_dir, f"{job_dir}/requirements.frozen.txt", ) if api_key_env is None: for env in container.get("env", []): if env["name"] == "WANDB_API_KEY": api_key_env = env break return api_key_env def apply_code_mount_configuration_emptydir( manifest: dict | list, project: LaunchProject, api: Api ) -> None: """Apply emptyDir code mount configuration when no PVC is available. Uses an init container to fetch code into an emptyDir volume, which is then mounted into all main containers. Args: manifest: The manifest to modify. project: The launch project. api: The internal API instance (for base_url). """ base_url = api.settings("base_url") source_type = project.job_source_type source_info = project.job_source_info install_deps = project._auto_default_base_image # Validate before mutating the manifest. if source_type not in ("artifact", "repo"): raise LaunchError( f"Cannot use emptyDir code mount for unknown source type: {source_type!r}" ) job_dir = f"{CODE_MOUNT_DIR}/.job" for pod in yield_pods(manifest): spec = pod["spec"] spec.setdefault("volumes", []).append( {"name": "wandb-source-code-volume", "emptyDir": {}} ) api_key_env = _configure_containers_for_code_mount( pod, project, install_deps, job_dir ) fetch_script = _build_code_fetch_script( source_type, source_info, install_deps, job_dir ) init_container = _build_source_init_container( fetch_script, base_url, api_key_env ) spec.setdefault("initContainers", []).append(init_container)