| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481 |
- """Implementation of KubernetesRunner class for wandb launch."""
- from __future__ import annotations
- import asyncio
- import base64
- import datetime
- import json
- import logging
- import os
- import shlex
- import time
- from collections.abc import Iterator
- from typing import Any
- import yaml
- import wandb
- from wandb.apis.internal import Api
- from wandb.sdk.launch.agent.agent import LaunchAgent
- from wandb.sdk.launch.environment.abstract import AbstractEnvironment
- from wandb.sdk.launch.registry.abstract import AbstractRegistry
- from wandb.sdk.launch.registry.azure_container_registry import AzureContainerRegistry
- from wandb.sdk.launch.registry.local_registry import LocalRegistry
- from wandb.sdk.launch.runner.abstract import Status
- from wandb.sdk.launch.runner.kubernetes_monitor import (
- WANDB_K8S_LABEL_AGENT,
- WANDB_K8S_LABEL_AUXILIARY_RESOURCE,
- WANDB_K8S_LABEL_MONITOR,
- WANDB_K8S_RUN_ID,
- CustomResource,
- LaunchKubernetesMonitor,
- )
- from wandb.sdk.launch.utils import (
- recursive_macro_sub,
- sanitize_identifiers_for_k8s,
- yield_containers,
- )
- from wandb.sdk.lib.retry import ExponentialBackoff, retry_async
- from wandb.util import get_module
- from .._project_spec import EntryPoint, LaunchProject
- from ..errors import LaunchError
- from ..utils import (
- CODE_MOUNT_DIR,
- LOG_PREFIX,
- MAX_ENV_LENGTHS,
- PROJECT_SYNCHRONOUS,
- get_kube_context_and_api_client,
- make_k8s_label_safe,
- make_name_dns_safe,
- )
- from .abstract import AbstractRun, AbstractRunner
- get_module(
- "kubernetes_asyncio",
- required="Kubernetes runner requires the kubernetes package. Please install it with `pip install wandb[launch]`.",
- )
- import kubernetes_asyncio # type: ignore # noqa: E402
- from kubernetes_asyncio import client # noqa: E402
- from kubernetes_asyncio.client.api.apps_v1_api import ( # type: ignore # noqa: E402
- AppsV1Api,
- )
- from kubernetes_asyncio.client.api.batch_v1_api import ( # type: ignore # noqa: E402
- BatchV1Api,
- )
- from kubernetes_asyncio.client.api.core_v1_api import ( # type: ignore # noqa: E402
- CoreV1Api,
- )
- from kubernetes_asyncio.client.api.custom_objects_api import ( # type: ignore # noqa: E402
- CustomObjectsApi,
- )
- from kubernetes_asyncio.client.api.networking_v1_api import ( # type: ignore # noqa: E402
- NetworkingV1Api,
- )
- from kubernetes_asyncio.client.models.v1_secret import ( # type: ignore # noqa: E402
- V1Secret,
- )
- from kubernetes_asyncio.client.rest import ApiException # type: ignore # noqa: E402
# NOTE(review): presumably a timeout in seconds; no use of this constant is
# visible in this part of the file -- confirm where it is consumed.
TIMEOUT = 5
# Retry budget handed to ExponentialBackoff when ensuring the API key secret
# exists (see KubernetesRunner._inject_defaults).
API_KEY_SECRET_MAX_RETRIES = 5
# Module-level logger.
_logger = logging.getLogger(__name__)
# Optional PVC-backed code mount, read once from the environment at import
# time. Both values must be set for the PVC code path to be chosen; otherwise
# the emptyDir code mount is used for jobs with a base image.
SOURCE_CODE_PVC_MOUNT_PATH = os.environ.get("WANDB_LAUNCH_CODE_PVC_MOUNT_PATH")
SOURCE_CODE_PVC_NAME = os.environ.get("WANDB_LAUNCH_CODE_PVC_NAME")
class KubernetesSubmittedRun(AbstractRun):
    """Wrapper for a launched run on Kubernetes."""

    def __init__(
        self,
        batch_api: BatchV1Api,
        core_api: CoreV1Api,
        apps_api: AppsV1Api,
        network_api: NetworkingV1Api,
        name: str,
        namespace: str | None = "default",
        secret: V1Secret | None = None,
        auxiliary_resource_label_key: str | None = None,
    ) -> None:
        """Initialize a KubernetesSubmittedRun.

        This object does not query the cluster for the job status itself:
        `get_status` returns whatever `LaunchKubernetesMonitor` currently
        reports for the job name.

        Args:
            batch_api: Kubernetes BatchV1Api object.
            core_api: Kubernetes CoreV1Api object.
            apps_api: Kubernetes AppsV1Api object.
            network_api: Kubernetes NetworkingV1Api object.
            name: Name of the job.
            namespace: Kubernetes namespace.
            secret: Kubernetes secret associated with the run. It is deleted
                when the run reaches a terminal state or is cancelled, unless
                `WANDB_RELEASE_NAME` is set in the environment.
            auxiliary_resource_label_key: Value of the
                `WANDB_K8S_LABEL_AUXILIARY_RESOURCE` label used to find the
                auxiliary resources that belong to this run. When None, no
                auxiliary cleanup is attempted.

        Returns:
            None.
        """
        self.batch_api = batch_api
        self.core_api = core_api
        self.apps_api = apps_api
        self.network_api = network_api
        self.name = name
        self.namespace = namespace
        self._fail_count = 0
        self.secret = secret
        self.auxiliary_resource_label_key = auxiliary_resource_label_key

    @property
    def id(self) -> str:
        """Return the run id."""
        return self.name

    async def get_logs(self) -> str | None:
        """Return the logs of the first pod of the job, or None.

        Best effort: any failure is reported with `wandb.termerror` and
        results in None rather than an exception.
        """
        try:
            pods = await self.core_api.list_namespaced_pod(
                label_selector=f"job-name={self.name}", namespace=self.namespace
            )
            pod_names = [pi.metadata.name for pi in pods.items]
            if not pod_names:
                wandb.termwarn(f"Found no pods for kubernetes job: {self.name}")
                return None
            logs = await self.core_api.read_namespaced_pod_log(
                name=pod_names[0], namespace=self.namespace
            )
            if logs:
                return str(logs)
            else:
                wandb.termwarn(f"No logs for kubernetes pod(s): {pod_names}")
                return None
        except Exception as e:
            wandb.termerror(f"{LOG_PREFIX}Failed to get pod logs: {e}")
            return None

    async def wait(self) -> bool:
        """Wait for the run to finish.

        Returns:
            True if the run finished successfully, False otherwise.
        """
        while True:
            status = await self.get_status()
            wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status.state}")
            if status.state in ["finished", "failed", "preempted"]:
                break
            await asyncio.sleep(5)

        await self._delete_secret()
        return (
            status.state == "finished"
        )  # todo: not sure if this (copied from aws runner) is the right approach? should we return false on failure

    async def get_status(self) -> Status:
        """Return the monitor's status for this job, cleaning up when terminal.

        Once the job is stopped, failed, finished or preempted, the run's
        secret and its labelled auxiliary resources are deleted.
        """
        status = LaunchKubernetesMonitor.get_status(self.name)
        # Compare on `.state`, as `wait` does, instead of relying on the
        # Status object comparing equal to a plain string.
        if status.state in ["stopped", "failed", "finished", "preempted"]:
            await self._delete_secret()
            await self._delete_auxiliary_resources_by_label()
        return status

    async def cancel(self) -> None:
        """Cancel the run.

        Deletes the job, then the run's secret and auxiliary resources.

        Raises:
            LaunchError: If the Kubernetes API rejects a deletion.
        """
        try:
            await self.batch_api.delete_namespaced_job(
                namespace=self.namespace,
                name=self.name,
            )
            await self._delete_secret()
            await self._delete_auxiliary_resources_by_label()
        except ApiException as e:
            raise LaunchError(
                f"Failed to delete Kubernetes Job {self.name} in namespace {self.namespace}: {str(e)}"
            ) from e

    async def _delete_secret(self) -> None:
        """Delete the run's secret once; later calls are no-ops."""
        # Cleanup secret if not running in a helm-managed context
        if not os.environ.get("WANDB_RELEASE_NAME") and self.secret:
            await self.core_api.delete_namespaced_secret(
                name=self.secret.metadata.name,
                namespace=self.secret.metadata.namespace,
            )
            self.secret = None

    async def _delete_auxiliary_resources_by_label(self) -> None:
        """Delete every namespaced resource carrying this run's auxiliary label.

        Best effort: failures are reported with `wandb.termwarn` and never
        raised, so one resource type failing does not block the others.
        """
        if self.auxiliary_resource_label_key is None:
            return

        label_selector = (
            f"{WANDB_K8S_LABEL_AUXILIARY_RESOURCE}={self.auxiliary_resource_label_key}"
        )

        try:
            # (client, suffix) pairs; the suffix selects the
            # list_namespaced_* / delete_namespaced_* methods on the client.
            resource_cleanups = [
                (self.core_api, "service"),
                (self.batch_api, "job"),
                (self.core_api, "pod"),
                (self.core_api, "secret"),
                (self.apps_api, "deployment"),
                (self.network_api, "network_policy"),
            ]

            for api_client, resource_type in resource_cleanups:
                try:
                    list_method = getattr(
                        api_client, f"list_namespaced_{resource_type}"
                    )
                    delete_method = getattr(
                        api_client, f"delete_namespaced_{resource_type}"
                    )

                    # List resources with our label
                    resources = await list_method(
                        namespace=self.namespace, label_selector=label_selector
                    )

                    # Delete each resource
                    for resource in resources.items:
                        await delete_method(
                            name=resource.metadata.name, namespace=self.namespace
                        )
                except (AttributeError, ApiException) as e:
                    wandb.termwarn(f"Could not clean up {resource_type}: {e}")
        except Exception as e:
            wandb.termwarn(f"Failed to clean up some auxiliary resources: {e}")
class CrdSubmittedRun(AbstractRun):
    """Run submitted to a CRD backend, e.g. Volcano."""

    def __init__(
        self,
        group: str,
        version: str,
        plural: str,
        name: str,
        namespace: str,
        core_api: CoreV1Api,
        custom_api: CustomObjectsApi,
    ) -> None:
        """Create a run object for tracking the progress of a CRD.

        The constructor only stores its arguments; it performs no API calls.

        Args:
            group: The API group of the CRD.
            version: The API version of the CRD.
            plural: The plural name of the CRD.
            name: The name of the CRD instance.
            namespace: The namespace of the CRD instance.
            core_api: The Kubernetes core API client.
            custom_api: The Kubernetes custom object API client.
        """
        # Identity of the custom object.
        self.name = name
        self.namespace = namespace
        self.group = group
        self.version = version
        self.plural = plural
        # API clients used for log retrieval and deletion.
        self.core_api = core_api
        self.custom_api = custom_api
        self._fail_count = 0

    @property
    def id(self) -> str:
        """Get the name of the custom object."""
        return self.name

    async def get_logs(self) -> str | None:
        """Get logs for custom object.

        Returns:
            The logs of every pod labelled with this object's name, one
            "Pod <name>:" section per pod, or None when there are no pods or
            the Kubernetes API call fails.
        """
        # TODO: test more carefully once we release multi-node support
        try:
            pod_list = await self.core_api.list_namespaced_pod(
                label_selector=f"wandb/run-id={self.name}", namespace=self.namespace
            )
            pod_names = [pi.metadata.name for pi in pod_list.items]
            logs: dict[str, str | None] = {
                pod_name: await self.core_api.read_namespaced_pod_log(
                    name=pod_name, namespace=self.namespace
                )
                for pod_name in pod_names
            }
        except ApiException as e:
            wandb.termwarn(f"Failed to get logs for {self.name}: {str(e)}")
            return None

        if not logs:
            return None
        return "\n".join(f"Pod {pod_name}:\n{log}" for pod_name, log in logs.items())

    async def get_status(self) -> Status:
        """Get status of custom object."""
        return LaunchKubernetesMonitor.get_status(self.name)

    async def cancel(self) -> None:
        """Cancel the custom object.

        Raises:
            LaunchError: If the Kubernetes API rejects the deletion.
        """
        try:
            await self.custom_api.delete_namespaced_custom_object(
                group=self.group,
                version=self.version,
                namespace=self.namespace,
                plural=self.plural,
                name=self.name,
            )
        except ApiException as e:
            raise LaunchError(
                f"Failed to delete CRD {self.name} in namespace {self.namespace}: {str(e)}"
            ) from e

    async def wait(self) -> bool:
        """Wait for this custom object to finish running.

        Returns:
            True if the final state is "finished", False otherwise.
        """
        status = await self.get_status()
        wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}")
        # Poll every five seconds until a terminal state is reported.
        while status.state not in ["finished", "failed", "preempted"]:
            await asyncio.sleep(5)
            status = await self.get_status()
            wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}")
        return status.state == "finished"
- class KubernetesRunner(AbstractRunner):
- """Launches runs onto kubernetes."""
def __init__(
    self,
    api: Api,
    backend_config: dict[str, Any],
    environment: AbstractEnvironment,
    registry: AbstractRegistry,
) -> None:
    """Create a Kubernetes runner.

    Args:
        api: The API client object.
        backend_config: The backend configuration.
        environment: The environment to launch runs into.
        registry: The registry stored on the runner; it is later passed to
            `maybe_create_imagepull_secret` when an image pull secret may be
            needed.
    """
    super().__init__(api, backend_config)
    self.registry = registry
    self.environment = environment
def get_namespace(
    self, resource_args: dict[str, Any], context: dict[str, Any]
) -> str:
    """Get the namespace to launch into.

    Precedence, first non-empty value wins:
      1. `resource_args["metadata"]["namespace"]`
      2. `resource_args["namespace"]` (legacy, malformed location)
      3. `backend_config["runner"]["namespace"]`
      4. the namespace of the active kube context
      5. "default"

    Sections that are missing or explicitly null are skipped rather than
    raising.

    Args:
        resource_args: The resource args to launch.
        context: The k8s config context. May be empty or None.

    Returns:
        The namespace to launch into.
    """
    # `or {}` guards against sections that are present but null.
    metadata = resource_args.get("metadata") or {}
    runner_config = self.backend_config.get("runner") or {}
    kube_context = (context or {}).get("context") or {}

    return (  # type: ignore[no-any-return]
        metadata.get("namespace")
        or resource_args.get("namespace")  # continue support for malformed namespace
        or runner_config.get("namespace")
        or kube_context.get("namespace")
        or "default"
    )
async def _inject_defaults(
    self,
    resource_args: dict[str, Any],
    launch_project: LaunchProject,
    image_uri: str,
    namespace: str,
    core_api: CoreV1Api,
) -> tuple[dict[str, Any], V1Secret | None]:
    """Apply our default values, return job dict and api key secret.

    Starting from `resource_args`, this fills in the Job apiVersion/kind,
    job and pod spec defaults, monitoring labels, a generated name when none
    is given, per-container names and security contexts, the image,
    entrypoint/args, environment variables (routing `WANDB_API_KEY` through a
    Kubernetes secret when running under an agent or synchronously) and the
    code mount configuration for base-image jobs.

    Args:
        resource_args (Dict[str, Any]): The resource args to launch.
        launch_project (LaunchProject): The launch project.
        image_uri (str): The image to run when the spec does not name one.
        namespace (str): The namespace.
        core_api (CoreV1Api): The core api.

    Returns:
        Tuple[Dict[str, Any], Optional["V1Secret"]]: The resource args and api key secret.
    """
    job: dict[str, Any] = {
        "apiVersion": "batch/v1",
        "kind": "Job",
    }
    job.update(resource_args)

    job_metadata: dict[str, Any] = job.get("metadata", {})
    job_spec: dict[str, Any] = {"backoffLimit": 0, "ttlSecondsAfterFinished": 60}
    job_spec.update(job.get("spec", {}))
    pod_template: dict[str, Any] = job_spec.get("template", {})
    pod_spec: dict[str, Any] = {"restartPolicy": "Never"}
    pod_spec.update(pod_template.get("spec", {}))
    # An empty/null container list is treated like a missing one, since the
    # image is always injected into containers[0] below.
    containers: list[dict[str, Any]] = pod_spec.get("containers") or [{}]

    # Add labels to job metadata
    job_metadata.setdefault("labels", {})
    job_metadata["labels"][WANDB_K8S_RUN_ID] = launch_project.run_id
    job_metadata["labels"][WANDB_K8S_LABEL_MONITOR] = "true"
    if LaunchAgent.initialized():
        job_metadata["labels"][WANDB_K8S_LABEL_AGENT] = LaunchAgent.name()

    # name precedence: name in spec > generated name
    if not job_metadata.get("name"):
        job_metadata["generateName"] = make_name_dns_safe(
            f"launch-{launch_project.target_entity}-{launch_project.target_project}-"
        )
    job_metadata["namespace"] = namespace

    for i, cont in enumerate(containers):
        cont.setdefault("name", "launch" + str(i))
        if "securityContext" not in cont:
            cont["securityContext"] = {
                "allowPrivilegeEscalation": False,
                "capabilities": {"drop": ["ALL"]},
                "seccompProfile": {"type": "RuntimeDefault"},
            }

    entry_point = (
        launch_project.override_entrypoint or launch_project.get_job_entry_point()
    )
    if launch_project.docker_image:
        # dont specify run id if user provided image, could have multiple runs
        containers[0]["image"] = image_uri
        # TODO: handle secret pulling image from registry
    elif not any("image" in cont for cont in containers):
        assert entry_point is not None
        # in the non instance case we need to make an imagePullSecret
        # so the new job can pull the image
        containers[0]["image"] = image_uri
        secret = await maybe_create_imagepull_secret(
            core_api, self.registry, launch_project.run_id, namespace
        )
        if secret is not None:
            pod_spec["imagePullSecrets"] = [
                {"name": f"regcred-{launch_project.run_id}"}
            ]

    inject_entrypoint_and_args(
        containers,
        entry_point,
        launch_project.override_args,
        launch_project.override_entrypoint is not None,
    )

    env_vars = launch_project.get_env_vars_dict(
        self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
    )

    def handle_exception(e: Exception) -> None:
        # Reported on each failed attempt of the retry loop below.
        wandb.termwarn(
            f"Exception when ensuring Kubernetes API key secret: {e}. Retrying..."
        )

    api_key_secret = None
    for cont in containers:
        # Add our env vars to user supplied env vars
        env = cont.get("env") or []
        for key, value in env_vars.items():
            if (
                key == "WANDB_API_KEY"
                and value
                and (
                    LaunchAgent.initialized()
                    or self.backend_config[PROJECT_SYNCHRONOUS]
                )
            ):
                # Override API key with secret. TODO: Do the same for other runners
                release_name = os.environ.get("WANDB_RELEASE_NAME")
                secret_name = "wandb-api-key"
                if release_name:
                    secret_name += f"-{release_name}"
                else:
                    secret_name += f"-{launch_project.run_id}"

                api_key_secret = await retry_async(
                    backoff=ExponentialBackoff(
                        initial_sleep=datetime.timedelta(seconds=1),
                        max_sleep=datetime.timedelta(minutes=1),
                        max_retries=API_KEY_SECRET_MAX_RETRIES,
                    ),
                    fn=ensure_api_key_secret,
                    on_exc=handle_exception,
                    core_api=core_api,
                    secret_name=secret_name,
                    namespace=namespace,
                    api_key=value,
                )
                env.append(
                    {
                        "name": key,
                        "valueFrom": {
                            "secretKeyRef": {
                                "name": secret_name,
                                "key": "password",
                            }
                        },
                    }
                )
            else:
                env.append({"name": key, "value": value})
        cont["env"] = env

    # Reassemble the nested structure from the (possibly new) pieces.
    pod_spec["containers"] = containers
    pod_template["spec"] = pod_spec
    job_spec["template"] = pod_template
    job["spec"] = job_spec
    job["metadata"] = job_metadata

    add_label_to_pods(
        job,
        WANDB_K8S_LABEL_MONITOR,
        "true",
    )

    if launch_project.job_base_image:
        if SOURCE_CODE_PVC_NAME and SOURCE_CODE_PVC_MOUNT_PATH:
            apply_code_mount_configuration(
                job,
                launch_project,
            )
        else:
            apply_code_mount_configuration_emptydir(
                job,
                launch_project,
                self._api,
            )

    # Add wandb.ai/agent: current agent label on all pods
    if LaunchAgent.initialized():
        add_label_to_pods(
            job,
            WANDB_K8S_LABEL_AGENT,
            LaunchAgent.name(),
        )

    return job, api_key_secret
async def _wait_for_resource_ready(
    self,
    api_client: kubernetes_asyncio.client.ApiClient,
    config: dict[str, Any],
    namespace: str,
    timeout_seconds: int = 300,
) -> None:
    """Block until the resource described by `config` reports ready.

    Deployments, Services and Pods get a kind-specific readiness check. Any
    other kind has no check and is given a fixed five second pause. When the
    config has no kind or no metadata name, an error is logged and the call
    returns without waiting.

    Args:
        api_client: The Kubernetes API client.
        config: The resource configuration.
        namespace: The namespace where the resource was created.
        timeout_seconds: Maximum time to wait for readiness.
    """
    resource_kind = config.get("kind")
    resource_name = config.get("metadata", {}).get("name")

    if not (resource_kind and resource_name):
        wandb.termerror(
            f"{LOG_PREFIX}Cannot wait for resource without kind or name"
        )
        return

    wandb.termlog(
        f"{LOG_PREFIX}Waiting for {resource_kind} '{resource_name}' to be ready..."
    )
    began = time.time()

    # Pick the kind-specific waiter, if there is one.
    waiter = None
    if resource_kind == "Deployment":
        waiter = self._wait_for_deployment_ready
    elif resource_kind == "Service":
        waiter = self._wait_for_service_ready
    elif resource_kind == "Pod":
        waiter = self._wait_for_pod_ready

    if waiter is not None:
        await waiter(api_client, resource_name, namespace, timeout_seconds)
    else:
        wandb.termlog(
            f"{LOG_PREFIX}No specific readiness check for {resource_kind}, waiting 5 seconds..."
        )
        await asyncio.sleep(5)

    elapsed = time.time() - began
    wandb.termlog(
        f"{LOG_PREFIX}{resource_kind} '{resource_name}' is ready after {elapsed:.1f}s"
    )
async def _wait_for_deployment_ready(
    self,
    api_client: kubernetes_asyncio.client.ApiClient,
    name: str,
    namespace: str,
    timeout_seconds: int,
) -> None:
    """Wait until the Deployment's ready replicas reach its replica count."""
    apps_api = kubernetes_asyncio.client.AppsV1Api(api_client)

    async def replicas_ready():
        deployment = await apps_api.read_namespaced_deployment(
            name=name, namespace=namespace
        )
        status = deployment.status
        ready = status.ready_replicas
        # Unset or zero counts mean nothing is ready yet.
        if not ready:
            return False
        desired = status.replicas
        if not desired:
            return False
        return ready >= desired

    await self._wait_with_timeout(replicas_ready, timeout_seconds, name)
async def _wait_for_service_ready(
    self,
    api_client: kubernetes_asyncio.client.ApiClient,
    name: str,
    namespace: str,
    timeout_seconds: int,
) -> None:
    """Wait until the Service's Endpoints object lists at least one address."""
    core_api = kubernetes_asyncio.client.CoreV1Api(api_client)

    async def has_ready_address():
        endpoints = await core_api.read_namespaced_endpoints(
            name=name, namespace=namespace
        )
        # `addresses` on a subset are the ready pod addresses.
        return any(subset.addresses for subset in endpoints.subsets or [])

    await self._wait_with_timeout(has_ready_address, timeout_seconds, name)
async def _wait_for_pod_ready(
    self,
    api_client: kubernetes_asyncio.client.ApiClient,
    name: str,
    namespace: str,
    timeout_seconds: int,
) -> None:
    """Wait until the Pod is Running and all reported containers are ready."""
    core_api = kubernetes_asyncio.client.CoreV1Api(api_client)

    async def pod_is_ready():
        pod = await core_api.read_namespaced_pod(name=name, namespace=namespace)
        if pod.status.phase != "Running":
            return False
        container_statuses = pod.status.container_statuses
        # A running pod with no container statuses counts as ready.
        if not container_statuses:
            return True
        return all(entry.ready for entry in container_statuses)

    await self._wait_with_timeout(pod_is_ready, timeout_seconds, name)
- async def _wait_with_timeout(
- self, check_func, timeout_seconds: int, name: str
- ) -> None:
- """Generic timeout wrapper for readiness checks."""
- start_time = time.time()
- while time.time() - start_time < timeout_seconds:
- try:
- if await check_func():
- return
- except kubernetes_asyncio.client.ApiException as e:
- if e.status == 404:
- pass
- else:
- wandb.termerror(
- f"{LOG_PREFIX}Error waiting for resource '{name}': {e}"
- )
- raise
- except Exception as e:
- wandb.termerror(f"{LOG_PREFIX}Error waiting for resource '{name}': {e}")
- raise
- await asyncio.sleep(2)
- raise LaunchError(
- f"Resource '{name}' not ready within {timeout_seconds} seconds"
- )
async def _prepare_resource(
    self,
    api_client: kubernetes_asyncio.client.ApiClient,
    config: dict[str, Any],
    namespace: str,
    run_id: str,
    launch_project: LaunchProject,
    api_key_secret: V1Secret | None = None,
    wait_for_ready: bool = True,
    wait_timeout: int = 300,
    auxiliary_resource_label_value: str | None = None,
) -> None:
    """Label, configure and create an additional resource for a launch.

    The config is mutated in place: run labels are added, `WANDB_CONFIG` is
    injected through `add_wandb_env`, and `WANDB_API_KEY` is wired to the API
    key secret on every container when a secret is given. The resource is
    then created and, optionally, awaited until ready.

    Args:
        api_client: The Kubernetes API client.
        config: The resource configuration to prepare.
        namespace: The namespace to create the resource in.
        run_id: The run ID to label the resource with.
        launch_project: The launch project to get environment variables from.
        api_key_secret: The API key secret to inject.
        wait_for_ready: Whether to wait for the resource to be ready after creation.
        wait_timeout: Maximum time in seconds to wait for resource readiness.
        auxiliary_resource_label_value: When set, stored under the
            `WANDB_K8S_LABEL_AUXILIARY_RESOURCE` label on the resource and
            on its pods.

    Raises:
        LaunchError: If sanitizing, creating or waiting for the resource
            fails; the original exception is chained as the cause.
    """
    config.setdefault("metadata", {})
    config["metadata"].setdefault("labels", {})
    config["metadata"]["labels"][WANDB_K8S_RUN_ID] = run_id
    config["metadata"]["labels"]["wandb.ai/created-by"] = "launch-agent"

    if auxiliary_resource_label_value:
        config["metadata"]["labels"][WANDB_K8S_LABEL_AUXILIARY_RESOURCE] = (
            auxiliary_resource_label_value
        )

    env_vars = launch_project.get_env_vars_dict(
        self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
    )
    # Only WANDB_CONFIG is forwarded to the additional resource.
    wandb_config_env = {
        "WANDB_CONFIG": env_vars.get("WANDB_CONFIG", "{}"),
    }
    add_wandb_env(config, wandb_config_env)

    if auxiliary_resource_label_value:
        add_label_to_pods(
            config,
            WANDB_K8S_LABEL_AUXILIARY_RESOURCE,
            auxiliary_resource_label_value,
        )

    if api_key_secret:
        for cont in yield_containers(config):
            # `or []` also covers an explicit `env: null`, matching how
            # _inject_defaults reads container env.
            env = cont.get("env") or []
            env.append(
                {
                    "name": "WANDB_API_KEY",
                    "valueFrom": {
                        "secretKeyRef": {
                            "name": api_key_secret.metadata.name,
                            "key": "password",
                        }
                    },
                }
            )
            cont["env"] = env

    try:
        sanitize_identifiers_for_k8s(config)
        await kubernetes_asyncio.utils.create_from_dict(
            api_client, config, namespace=namespace
        )

        if wait_for_ready:
            await self._wait_for_resource_ready(
                api_client, config, namespace, wait_timeout
            )
    except Exception as e:
        wandb.termerror(f"{LOG_PREFIX}Failed to create Kubernetes resource: {e}")
        raise LaunchError(f"Failed to create Kubernetes resource: {e}") from e
- async def run(
- self, launch_project: LaunchProject, image_uri: str
- ) -> AbstractRun | None:
- """Execute a launch project on Kubernetes.
- Args:
- launch_project: The launch project to execute.
- builder: The builder to use to build the image.
- Returns:
- The run object if the run was successful, otherwise None.
- """
- await LaunchKubernetesMonitor.ensure_initialized()
- resource_args = launch_project.fill_macros(image_uri).get("kubernetes", {})
- if not resource_args:
- wandb.termlog(
- f"{LOG_PREFIX}Note: no resource args specified. Add a "
- "Kubernetes yaml spec or other options in a json file "
- "with --resource-args <json>."
- )
- _logger.info(f"Running Kubernetes job with resource args: {resource_args}")
- context, api_client = await get_kube_context_and_api_client(
- kubernetes_asyncio, resource_args
- )
- # If using pvc for code mount, move code there.
- use_emptydir_code_mount = False
- if launch_project.job_base_image is not None:
- if SOURCE_CODE_PVC_NAME and SOURCE_CODE_PVC_MOUNT_PATH:
- code_subdir = launch_project.get_image_source_string()
- launch_project.change_project_dir(
- os.path.join(SOURCE_CODE_PVC_MOUNT_PATH, code_subdir)
- )
- else:
- use_emptydir_code_mount = True
- # If the user specified an alternate api, we need will execute this
- # run by creating a custom object.
- api_version = resource_args.get("apiVersion", "batch/v1")
- if api_version not in ["batch/v1", "batch/v1beta1"]:
- env_vars = launch_project.get_env_vars_dict(
- self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
- )
- # Crawl the resource args and add our env vars to the containers.
- add_wandb_env(resource_args, env_vars)
- # Add our labels to the resource args. This is necessary for the
- # agent to find the custom object later on.
- resource_args["metadata"] = resource_args.get("metadata", {})
- resource_args["metadata"]["labels"] = resource_args["metadata"].get(
- "labels", {}
- )
- resource_args["metadata"]["labels"][WANDB_K8S_LABEL_MONITOR] = "true"
- # Crawl the resource arsg and add our labels to the pods. This is
- # necessary for the agent to find the pods later on.
- add_label_to_pods(
- resource_args,
- WANDB_K8S_LABEL_MONITOR,
- "true",
- )
- # Add wandb.ai/agent: current agent label on all pods
- if LaunchAgent.initialized():
- add_label_to_pods(
- resource_args,
- WANDB_K8S_LABEL_AGENT,
- LaunchAgent.name(),
- )
- resource_args["metadata"]["labels"][WANDB_K8S_LABEL_AGENT] = (
- LaunchAgent.name()
- )
- if launch_project.job_base_image:
- if use_emptydir_code_mount:
- apply_code_mount_configuration_emptydir(
- resource_args, launch_project, self._api
- )
- else:
- apply_code_mount_configuration(resource_args, launch_project)
- overrides = {}
- if launch_project.override_args:
- overrides["args"] = launch_project.override_args
- if launch_project.override_entrypoint:
- overrides["command"] = launch_project.override_entrypoint.command
- add_entrypoint_args_overrides(
- resource_args,
- overrides,
- )
- api = client.CustomObjectsApi(api_client)
- # Infer the attributes of a custom object from the apiVersion and/or
- # a kind: attribute in the resource args.
- namespace = self.get_namespace(resource_args, context)
- group, version, *_ = api_version.split("/")
- group = resource_args.get("group", group)
- version = resource_args.get("version", version)
- kind = resource_args.get("kind", version)
- plural = f"{kind.lower()}s"
- custom_resource = CustomResource(
- group=group,
- version=version,
- plural=plural,
- )
- LaunchKubernetesMonitor.monitor_namespace(
- namespace, custom_resource=custom_resource
- )
- try:
- response = await api.create_namespaced_custom_object(
- group=group,
- version=version,
- namespace=namespace,
- plural=plural,
- body=resource_args,
- )
- except ApiException as e:
- body = json.loads(e.body)
- body_yaml = yaml.dump(body)
- raise LaunchError(
- f"Error creating CRD of kind {kind}: {e.status} {e.reason}\n{body_yaml}"
- ) from e
- name = response.get("metadata", {}).get("name")
- _logger.info(f"Created {kind} {response['metadata']['name']}")
- submitted_run = CrdSubmittedRun(
- name=name,
- group=group,
- version=version,
- namespace=namespace,
- plural=plural,
- core_api=client.CoreV1Api(api_client),
- custom_api=api,
- )
- if self.backend_config[PROJECT_SYNCHRONOUS]:
- await submitted_run.wait()
- return submitted_run
- batch_api = kubernetes_asyncio.client.BatchV1Api(api_client)
- core_api = kubernetes_asyncio.client.CoreV1Api(api_client)
- apps_api = kubernetes_asyncio.client.AppsV1Api(api_client)
- network_api = kubernetes_asyncio.client.NetworkingV1Api(api_client)
- namespace = self.get_namespace(resource_args, context)
- job, secret = await self._inject_defaults(
- resource_args, launch_project, image_uri, namespace, core_api
- )
- update_dict = {
- "project_name": launch_project.target_project,
- "entity_name": launch_project.target_entity,
- "run_id": launch_project.run_id,
- "run_name": launch_project.name,
- "image_uri": image_uri,
- "author": launch_project.author,
- }
- update_dict.update(os.environ)
- additional_services: list[dict[str, Any]] = recursive_macro_sub(
- launch_project.launch_spec.get("additional_services", []), update_dict
- )
- auxiliary_resource_label_value = make_k8s_label_safe(
- f"aux-{launch_project.target_entity}-{launch_project.target_project}-{launch_project.run_id}"
- )
- if additional_services:
- wandb.termlog(
- f"{LOG_PREFIX}Creating additional services: {additional_services}"
- )
- wait_for_ready = resource_args.get("wait_for_ready", True)
- wait_timeout = resource_args.get("wait_timeout", 300)
- await asyncio.gather(
- *[
- self._prepare_resource(
- api_client,
- resource.get("config", {}),
- namespace,
- launch_project.run_id,
- launch_project,
- secret,
- wait_for_ready,
- wait_timeout,
- auxiliary_resource_label_value,
- )
- for resource in additional_services
- if resource.get("config", {})
- ]
- )
- msg = "Creating Kubernetes job"
- if "name" in resource_args:
- msg += f": {resource_args['name']}"
- _logger.info(msg)
- try:
- response = await kubernetes_asyncio.utils.create_from_dict(
- api_client, job, namespace=namespace
- )
- except kubernetes_asyncio.utils.FailToCreateError as e:
- for exc in e.api_exceptions:
- resp = json.loads(exc.body)
- msg = resp.get("message")
- code = resp.get("code")
- raise LaunchError(
- f"Failed to create Kubernetes job for run {launch_project.run_id} ({code} {exc.reason}): {msg}"
- )
- except Exception as e:
- raise LaunchError(
- f"Unexpected exception when creating Kubernetes job: {str(e)}\n"
- )
- job_response = response[0]
- job_name = job_response.metadata.name
- LaunchKubernetesMonitor.monitor_namespace(namespace)
- submitted_job = KubernetesSubmittedRun(
- batch_api,
- core_api,
- apps_api,
- network_api,
- job_name,
- namespace,
- secret,
- auxiliary_resource_label_value,
- )
- if self.backend_config[PROJECT_SYNCHRONOUS]:
- await submitted_job.wait()
- return submitted_job
def inject_entrypoint_and_args(
    containers: list[dict],
    entry_point: EntryPoint | None,
    override_args: list[str],
    should_override_entrypoint: bool,
) -> None:
    """Inject the entrypoint and args into the containers.

    Every container dict is modified in place.

    Args:
        containers: The containers to inject the entrypoint and args into.
        entry_point: The entrypoint to inject. Ignored when falsy.
        override_args: The args to inject. When empty, existing ``args``
            are left untouched.
        should_override_entrypoint: Whether to replace a ``command`` that a
            container already defines. Containers without a (truthy)
            ``command`` always receive the entrypoint.

    Returns:
        None
    """
    for container in containers:
        if override_args:
            container["args"] = override_args
        # Only clobber an existing command when explicitly asked to.
        needs_command = not container.get("command") or should_override_entrypoint
        if entry_point and needs_command:
            container["command"] = entry_point.command
async def ensure_api_key_secret(
    core_api: CoreV1Api,
    secret_name: str,
    namespace: str,
    api_key: str,
) -> V1Secret:
    """Create a secret containing a user's wandb API key.

    If a secret with the same name already exists and holds identical data it
    is reused. If it holds different data it is replaced, but only when it
    carries the ``wandb.ai/created-by: launch-agent`` label.

    Args:
        core_api: The Kubernetes CoreV1Api object.
        secret_name: The name to use for the secret.
        namespace: The namespace to create the secret in.
        api_key: The user's wandb API key

    Returns:
        The created (or reused) secret

    Raises:
        LaunchError: If the secret cannot be created or read, or if a secret
            with the same name exists with different data and was not
            created by the launch agent.
    """
    secret_data = {"password": base64.b64encode(api_key.encode()).decode()}
    labels = {"wandb.ai/created-by": "launch-agent"}
    secret = client.V1Secret(
        data=secret_data,
        metadata=client.V1ObjectMeta(
            name=secret_name, namespace=namespace, labels=labels
        ),
        kind="Secret",
        type="kubernetes.io/basic-auth",
    )

    try:
        try:
            return await core_api.create_namespaced_secret(namespace, secret)
        except ApiException as e:
            # 409 = conflict = secret already exists; anything else is fatal.
            if e.status != 409:
                raise
            existing_secret = await core_api.read_namespaced_secret(
                name=secret_name, namespace=namespace
            )
            if existing_secret.data == secret_data:
                return existing_secret
            # A secret without labels reports ``labels`` as None; treat that
            # as "not created by the launch agent" instead of crashing.
            existing_labels = existing_secret.metadata.labels or {}
            if existing_labels.get("wandb.ai/created-by") != "launch-agent":
                raise LaunchError(
                    f"Kubernetes secret already exists in namespace {namespace} with incorrect data: {secret_name}"
                )
            # If it's a previous secret made by launch agent, clean it up
            await core_api.delete_namespaced_secret(
                name=secret_name, namespace=namespace
            )
            return await core_api.create_namespaced_secret(namespace, secret)
    except Exception as e:
        raise LaunchError(
            f"Exception when ensuring Kubernetes API key secret: {str(e)}\n"
        ) from e
async def maybe_create_imagepull_secret(
    core_api: CoreV1Api,
    registry: AbstractRegistry,
    run_id: str,
    namespace: str,
) -> V1Secret | None:
    """Create a secret for pulling images from a private registry.

    The secret is named ``regcred-<run_id>`` and has type
    ``kubernetes.io/dockerconfigjson``. If a secret with that name already
    exists (HTTP 409), the existing secret is read back and returned.

    Args:
        core_api: The Kubernetes CoreV1Api object.
        registry: The registry to pull from.
        run_id: The run id.
        namespace: The namespace to create the secret in.

    Returns:
        A secret if one was created, otherwise None.

    Raises:
        LaunchError: If creating or reading the secret fails.
    """
    if isinstance(registry, (LocalRegistry, AzureContainerRegistry)):
        # Secret not required
        return None

    secret_name = f"regcred-{run_id}"
    username, password = await registry.get_username_password()
    encoded_auth = base64.b64encode(f"{username}:{password}".encode()).decode()
    docker_config = {
        "auths": {
            registry.uri: {
                "auth": encoded_auth,
                # need an email but the use is deprecated
                "email": "deprecated@wandblaunch.com",
            }
        }
    }
    encoded_config = base64.b64encode(json.dumps(docker_config).encode()).decode()
    pull_secret = client.V1Secret(
        data={".dockerconfigjson": encoded_config},
        metadata=client.V1ObjectMeta(name=secret_name, namespace=namespace),
        kind="Secret",
        type="kubernetes.io/dockerconfigjson",
    )

    try:
        try:
            return await core_api.create_namespaced_secret(namespace, pull_secret)
        except ApiException as e:
            # 409 = conflict = secret already exists
            if e.status != 409:
                raise
            return await core_api.read_namespaced_secret(
                name=secret_name, namespace=namespace
            )
    except Exception as e:
        raise LaunchError(f"Exception when creating Kubernetes secret: {str(e)}\n")
def add_wandb_env(root: dict | list, env_vars: dict[str, str]) -> None:
    """Injects wandb environment variables into specs.

    For every container produced by ``yield_containers(root)``, one
    ``{"name": ..., "value": ...}`` entry per item of ``env_vars`` is appended
    to the container's ``env`` list (the list is created when missing).
    Entries are appended; existing entries with the same name are not removed.

    ``WANDB_RUN_ID`` is treated specially: it is removed from ``env_vars``
    after the first container has been processed, so only that container
    receives it. Note that this mutates the dict passed in by the caller.

    Args:
        root: The spec to modify.
        env_vars: The environment variables to inject.

    Returns: None.
    """
    for container in yield_containers(root):
        injected = [
            {"name": name, "value": value} for name, value in env_vars.items()
        ]
        container.setdefault("env", []).extend(injected)
        # After we have set WANDB_RUN_ID once, we don't want to set it again
        env_vars.pop("WANDB_RUN_ID", None)
def yield_pods(manifest: Any) -> Iterator[dict]:
    """Yield all pod specs in a manifest.

    Recursively traverses the manifest and yields every dict that has a
    "spec" key whose value is a mapping containing a "containers" key.
    Nested lists and dicts are searched as well; any other value is ignored.

    A "spec" that is not a dict (for example an explicit ``null`` or a plain
    string) never identifies a pod and is skipped instead of raising.

    Args:
        manifest: The manifest (or manifest fragment) to search.

    Yields:
        Each dict in the manifest that looks like a pod, in traversal order.
    """
    if isinstance(manifest, list):
        for item in manifest:
            yield from yield_pods(item)
    elif isinstance(manifest, dict):
        spec = manifest.get("spec")
        # Require a real mapping: ``"containers" in None`` raises TypeError
        # and ``"containers" in "some string"`` is a substring test.
        if isinstance(spec, dict) and "containers" in spec:
            yield manifest
        for value in manifest.values():
            if isinstance(value, (dict, list)):
                yield from yield_pods(value)
def add_label_to_pods(manifest: dict | list, label_key: str, label_value: str) -> None:
    """Add a label to all pod specs in a manifest.

    Every pod produced by ``yield_pods(manifest)`` gets
    ``metadata.labels[label_key] = label_value``. Missing ``metadata`` and
    ``labels`` dicts are created; an existing value for the key is replaced.

    Args:
        manifest: The manifest to modify.
        label_key: The label key to add.
        label_value: The label value to add.

    Returns: None.
    """
    for pod in yield_pods(manifest):
        pod_labels = pod.setdefault("metadata", {}).setdefault("labels", {})
        pod_labels[label_key] = label_value
def add_entrypoint_args_overrides(manifest: dict | list, overrides: dict) -> None:
    """Add entrypoint and args overrides to all containers in a manifest.

    Recursively traverses the manifest. Whenever a dict has a "spec" key whose
    value is a mapping with a "containers" key, every container in it has its
    "command" and/or "args" replaced by the corresponding entry of
    ``overrides``. Keys absent from ``overrides`` are left untouched.

    A "spec" that is not a dict, or a "containers" value that is null, is
    skipped rather than raising.

    Args:
        manifest: The manifest to modify.
        overrides: Dictionary with optional "command" and "args" keys.

    Returns: None.
    """
    if isinstance(manifest, list):
        for item in manifest:
            add_entrypoint_args_overrides(item, overrides)
        return
    if not isinstance(manifest, dict):
        # Scalars (strings, numbers, None) cannot contain containers.
        return

    spec = manifest.get("spec")
    # Guard against ``spec: null`` / string specs, which would otherwise make
    # the membership test raise TypeError or perform a substring match.
    if isinstance(spec, dict) and "containers" in spec:
        for container in spec["containers"] or []:
            for key in ("command", "args"):
                if key in overrides:
                    container[key] = overrides[key]

    for value in manifest.values():
        add_entrypoint_args_overrides(value, overrides)
def _set_container_command_with_dep_install(
    container: dict,
    working_dir: str,
    requirements_path: str,
) -> None:
    """Set a container's command to install dependencies then exec the original command.

    Replaces command/args with a shell one-liner that installs dependencies, checking in order:
    1. requirements.txt (user-provided)
    2. pyproject.toml (user-provided, installed via pip install .)
    3. requirements.frozen.txt (job artifact fallback)

    The original command and args are shell-quoted, joined, and run with
    ``exec`` after the install step succeeds. If the container has neither a
    command nor args (missing, empty, or explicitly null) it is left unchanged.

    Args:
        container: The container spec to modify in place.
        working_dir: The working directory where user dep files are expected.
        requirements_path: Path to the frozen requirements fallback file.
    """
    # ``or []`` also covers an explicit ``command: null`` / ``args: null``,
    # which ``dict.get(key, [])`` would return as None.
    original_command = container.get("command") or []
    original_args = container.get("args") or []
    original_cmd_str = " ".join(
        shlex.quote(part) for part in [*original_command, *original_args]
    )
    if not original_cmd_str:
        return

    user_requirements = f"{working_dir}/requirements.txt"
    pyproject = f"{working_dir}/pyproject.toml"
    install_prefix = (
        f"if [ -f {shlex.quote(user_requirements)} ]; then"
        f" pip install -r {shlex.quote(user_requirements)};"
        f" elif [ -f {shlex.quote(pyproject)} ]; then"
        f" pip install {shlex.quote(working_dir)};"
        f" elif [ -f {shlex.quote(requirements_path)} ]; then"
        f" pip install -r {shlex.quote(requirements_path)};"
        f" else echo 'No requirements file found'; fi"
    )
    container["command"] = ["/bin/sh", "-c"]
    container["args"] = [f"{install_prefix} && exec {original_cmd_str}"]
def apply_code_mount_configuration(
    manifest: dict | list, project: LaunchProject
) -> None:
    """Apply code mount configuration to all containers in a manifest.

    For every pod produced by ``yield_pods(manifest)``:

    * each container gets a ``wandb-source-code-volume`` mount at
      ``CODE_MOUNT_DIR`` (with the project's image source string as
      ``subPath``) and its ``workingDir`` set to the project's resolved
      working directory;
    * when the project uses the auto-assigned default base image, the
      container command is wrapped with a dependency-install step;
    * the pod spec gets a ``wandb-source-code-volume`` volume backed by the
      persistent volume claim named by ``SOURCE_CODE_PVC_NAME``.

    Args:
        manifest: The manifest to modify.
        project: The launch project.

    Returns: None.
    """
    assert SOURCE_CODE_PVC_NAME is not None
    sub_path = project.get_image_source_string()
    frozen_requirements = f"{CODE_MOUNT_DIR}/.job/requirements.frozen.txt"

    for pod in yield_pods(manifest):
        for container in yield_containers(pod):
            container.setdefault("volumeMounts", []).append(
                {
                    "name": "wandb-source-code-volume",
                    "mountPath": CODE_MOUNT_DIR,
                    "subPath": sub_path,
                }
            )
            container["workingDir"] = project.resolved_working_dir
            if project._auto_default_base_image:
                _set_container_command_with_dep_install(
                    container,
                    project.resolved_working_dir,
                    frozen_requirements,
                )

        pod["spec"].setdefault("volumes", []).append(
            {
                "name": "wandb-source-code-volume",
                "persistentVolumeClaim": {
                    "claimName": SOURCE_CODE_PVC_NAME,
                },
            }
        )
def _build_code_fetch_script(
    source_type: str,
    source_info: dict,
    install_deps: bool,
    job_dir: str,
) -> str:
    """Build the shell script for the init container to fetch source code.

    The script is a single ``&&`` chain: fetch the source (artifact download
    or ``git clone`` + ``git checkout``), optionally download the job
    artifact, then make the mounted code world-writable. The two ``chmod``
    steps are best-effort; a failure in any fetch step makes the whole script
    exit non-zero.

    Args:
        source_type: Either "artifact" or "repo". Any value other than
            "artifact" is handled as a repo.
        source_info: Source metadata from the launch project.
        install_deps: Whether to also fetch the job artifact for frozen requirements.
        job_dir: Path where the job artifact should be downloaded.

    Returns:
        The shell command line to pass to ``/bin/sh -c``.
    """
    # ``&&`` and ``||`` have equal precedence and associate left to right, so
    # a bare ``... && chmod X || true`` would also swallow a failure of the
    # fetch steps before it. Brace-group each chmod so that ``|| true`` only
    # forgives the chmod itself.
    chmod_suffix = (
        f" && {{ chmod -R a+w {CODE_MOUNT_DIR}/* || true; }}"
        f" && {{ chmod -R a+w {CODE_MOUNT_DIR}/.* || true; }}"
    )

    job_artifact = source_info.get("job_artifact", "")
    fetch_job_artifact = ""
    if install_deps and job_artifact:
        py_cmd = f"import wandb; wandb.Api().artifact({repr(job_artifact)}).download({repr(job_dir)})"
        fetch_job_artifact = f" && python -c {shlex.quote(py_cmd)}"

    if source_type == "artifact":
        artifact_string = source_info.get("artifact_string", "")
        py_cmd = f"import wandb; wandb.Api().artifact({repr(artifact_string)}).download({repr(CODE_MOUNT_DIR)})"
        fetch_source = f"python -c {shlex.quote(py_cmd)}"
    else:  # repo
        git_remote = source_info.get("git_remote", "")
        git_commit = source_info.get("git_commit", "")
        fetch_source = (
            f"git clone {shlex.quote(git_remote)} {CODE_MOUNT_DIR}"
            f" && git config --global --add safe.directory {CODE_MOUNT_DIR}"
            f" && cd {CODE_MOUNT_DIR} && git checkout {shlex.quote(git_commit)}"
        )

    return fetch_source + fetch_job_artifact + chmod_suffix
def _build_source_init_container(
    fetch_script: str,
    base_url: str,
    api_key_env: dict | None,
) -> dict:
    """Build the init container spec that fetches source code into the code volume.

    The container is named ``wandb-source-code-init``, mounts
    ``wandb-source-code-volume`` at ``CODE_MOUNT_DIR`` and runs
    ``fetch_script`` through ``/bin/sh -c``.

    Args:
        fetch_script: Shell script to run in the init container.
        base_url: W&B base URL passed as the WANDB_BASE_URL environment variable.
        api_key_env: Optional WANDB_API_KEY env dict extracted from a main
            container; appended to the init container's env when truthy.

    Returns:
        The init container spec as a plain dict.
    """
    env_entries: list[dict] = [{"name": "WANDB_BASE_URL", "value": base_url}]
    if api_key_env:
        env_entries = [*env_entries, api_key_env]

    code_mount = {"name": "wandb-source-code-volume", "mountPath": CODE_MOUNT_DIR}
    return {
        "name": "wandb-source-code-init",
        "image": "wandb/launch-agent:latest",
        "volumeMounts": [code_mount],
        "env": env_entries,
        "command": ["/bin/sh", "-c", fetch_script],
    }
def _configure_containers_for_code_mount(
    pod: dict,
    project: LaunchProject,
    install_deps: bool,
    job_dir: str,
) -> dict | None:
    """Mount the code volume on all main containers and return the first WANDB_API_KEY env entry.

    Each container produced by ``yield_containers(pod)`` gets the
    ``wandb-source-code-volume`` mount at ``CODE_MOUNT_DIR`` and has its
    ``workingDir`` set to the project's resolved working directory.

    Args:
        pod: The pod spec dict to modify in place.
        project: The launch project (for workingDir and dep-install config).
        install_deps: Whether to wrap container commands with a dep-install step.
        job_dir: Directory inside the mounted volume that holds
            ``requirements.frozen.txt``.

    Returns:
        The first env entry named ``WANDB_API_KEY`` found across the
        containers (in iteration order), or None if there is none.
    """
    frozen_requirements = f"{job_dir}/requirements.frozen.txt"
    found_api_key = None

    for container in yield_containers(pod):
        mounts = container.setdefault("volumeMounts", [])
        mounts.append({"name": "wandb-source-code-volume", "mountPath": CODE_MOUNT_DIR})
        container["workingDir"] = project.resolved_working_dir

        # Only install deps when using the auto-assigned default base image.
        # User-provided base images are expected to already have deps.
        if install_deps:
            _set_container_command_with_dep_install(
                container,
                project.resolved_working_dir,
                frozen_requirements,
            )

        # Keep the first match only; later containers are not searched.
        if found_api_key is None:
            found_api_key = next(
                (
                    entry
                    for entry in container.get("env", [])
                    if entry["name"] == "WANDB_API_KEY"
                ),
                None,
            )

    return found_api_key
def apply_code_mount_configuration_emptydir(
    manifest: dict | list, project: LaunchProject, api: Api
) -> None:
    """Apply emptyDir code mount configuration when no PVC is available.

    For every pod produced by ``yield_pods(manifest)`` this adds an
    ``emptyDir`` volume named ``wandb-source-code-volume``, mounts it into the
    pod's containers, and appends an init container that fetches the source
    code into that volume.

    Args:
        manifest: The manifest to modify.
        project: The launch project.
        api: The internal API instance (for base_url).

    Raises:
        LaunchError: If the project's job source type is neither "artifact"
            nor "repo". Raised before the manifest is touched.
    """
    base_url = api.settings("base_url")
    source_type = project.job_source_type
    source_info = project.job_source_info
    install_deps = project._auto_default_base_image

    # Validate before mutating the manifest.
    if source_type not in ("artifact", "repo"):
        raise LaunchError(
            f"Cannot use emptyDir code mount for unknown source type: {source_type!r}"
        )

    job_dir = f"{CODE_MOUNT_DIR}/.job"
    for pod in yield_pods(manifest):
        pod_spec = pod["spec"]
        volumes = pod_spec.setdefault("volumes", [])
        volumes.append({"name": "wandb-source-code-volume", "emptyDir": {}})

        api_key_env = _configure_containers_for_code_mount(
            pod, project, install_deps, job_dir
        )
        init_container = _build_source_init_container(
            _build_code_fetch_script(source_type, source_info, install_deps, job_dir),
            base_url,
            api_key_env,
        )
        pod_spec.setdefault("initContainers", []).append(init_container)
|