kubernetes_runner.py 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481
  1. """Implementation of KubernetesRunner class for wandb launch."""
  2. from __future__ import annotations
  3. import asyncio
  4. import base64
  5. import datetime
  6. import json
  7. import logging
  8. import os
  9. import shlex
  10. import time
  11. from collections.abc import Iterator
  12. from typing import Any
  13. import yaml
  14. import wandb
  15. from wandb.apis.internal import Api
  16. from wandb.sdk.launch.agent.agent import LaunchAgent
  17. from wandb.sdk.launch.environment.abstract import AbstractEnvironment
  18. from wandb.sdk.launch.registry.abstract import AbstractRegistry
  19. from wandb.sdk.launch.registry.azure_container_registry import AzureContainerRegistry
  20. from wandb.sdk.launch.registry.local_registry import LocalRegistry
  21. from wandb.sdk.launch.runner.abstract import Status
  22. from wandb.sdk.launch.runner.kubernetes_monitor import (
  23. WANDB_K8S_LABEL_AGENT,
  24. WANDB_K8S_LABEL_AUXILIARY_RESOURCE,
  25. WANDB_K8S_LABEL_MONITOR,
  26. WANDB_K8S_RUN_ID,
  27. CustomResource,
  28. LaunchKubernetesMonitor,
  29. )
  30. from wandb.sdk.launch.utils import (
  31. recursive_macro_sub,
  32. sanitize_identifiers_for_k8s,
  33. yield_containers,
  34. )
  35. from wandb.sdk.lib.retry import ExponentialBackoff, retry_async
  36. from wandb.util import get_module
  37. from .._project_spec import EntryPoint, LaunchProject
  38. from ..errors import LaunchError
  39. from ..utils import (
  40. CODE_MOUNT_DIR,
  41. LOG_PREFIX,
  42. MAX_ENV_LENGTHS,
  43. PROJECT_SYNCHRONOUS,
  44. get_kube_context_and_api_client,
  45. make_k8s_label_safe,
  46. make_name_dns_safe,
  47. )
  48. from .abstract import AbstractRun, AbstractRunner
  49. get_module(
  50. "kubernetes_asyncio",
  51. required="Kubernetes runner requires the kubernetes package. Please install it with `pip install wandb[launch]`.",
  52. )
  53. import kubernetes_asyncio # type: ignore # noqa: E402
  54. from kubernetes_asyncio import client # noqa: E402
  55. from kubernetes_asyncio.client.api.apps_v1_api import ( # type: ignore # noqa: E402
  56. AppsV1Api,
  57. )
  58. from kubernetes_asyncio.client.api.batch_v1_api import ( # type: ignore # noqa: E402
  59. BatchV1Api,
  60. )
  61. from kubernetes_asyncio.client.api.core_v1_api import ( # type: ignore # noqa: E402
  62. CoreV1Api,
  63. )
  64. from kubernetes_asyncio.client.api.custom_objects_api import ( # type: ignore # noqa: E402
  65. CustomObjectsApi,
  66. )
  67. from kubernetes_asyncio.client.api.networking_v1_api import ( # type: ignore # noqa: E402
  68. NetworkingV1Api,
  69. )
  70. from kubernetes_asyncio.client.models.v1_secret import ( # type: ignore # noqa: E402
  71. V1Secret,
  72. )
  73. from kubernetes_asyncio.client.rest import ApiException # type: ignore # noqa: E402
  # NOTE(review): not referenced in the visible part of this file; presumably a
  # Kubernetes API timeout in seconds -- confirm before relying on it.
  TIMEOUT = 5
  # Max retries (with exponential backoff) when ensuring the API key secret exists.
  API_KEY_SECRET_MAX_RETRIES = 5
  _logger = logging.getLogger(__name__)
  # When BOTH of these are set, job source code is placed under the mount path of
  # this PVC; otherwise the runner falls back to an emptyDir code mount.
  SOURCE_CODE_PVC_MOUNT_PATH = os.environ.get("WANDB_LAUNCH_CODE_PVC_MOUNT_PATH")
  SOURCE_CODE_PVC_NAME = os.environ.get("WANDB_LAUNCH_CODE_PVC_NAME")
  class KubernetesSubmittedRun(AbstractRun):
      """Wrapper for a launched run on Kubernetes."""

      # States after which the job makes no further progress. ``wait`` and
      # ``get_status`` share this tuple so they agree on when the run is over.
      _TERMINAL_STATES = ("stopped", "failed", "finished", "preempted")

      def __init__(
          self,
          batch_api: BatchV1Api,
          core_api: CoreV1Api,
          apps_api: AppsV1Api,
          network_api: NetworkingV1Api,
          name: str,
          namespace: str | None = "default",
          secret: V1Secret | None = None,
          auxiliary_resource_label_key: str | None = None,
      ) -> None:
          """Initialize a KubernetesSubmittedRun.

          Unlike runners that query their backend on every ``get_status`` call,
          this run returns the status that ``LaunchKubernetesMonitor`` has
          recorded for the job name.

          Args:
              batch_api: Kubernetes BatchV1Api object.
              core_api: Kubernetes CoreV1Api object.
              apps_api: Kubernetes AppsV1Api object.
              network_api: Kubernetes NetworkingV1Api object.
              name: Name of the job.
              namespace: Kubernetes namespace of the job.
              secret: API key secret to delete once the run is over, if any.
              auxiliary_resource_label_key: Value of the auxiliary-resource
                  label used to find auxiliary resources to clean up.
          """
          self.batch_api = batch_api
          self.core_api = core_api
          self.apps_api = apps_api
          self.network_api = network_api
          self.name = name
          self.namespace = namespace
          self._fail_count = 0
          self.secret = secret
          self.auxiliary_resource_label_key = auxiliary_resource_label_key

      @property
      def id(self) -> str:
          """Return the run id."""
          return self.name

      async def get_logs(self) -> str | None:
          """Return the logs of the first pod of the job, or None if unavailable."""
          try:
              pods = await self.core_api.list_namespaced_pod(
                  label_selector=f"job-name={self.name}", namespace=self.namespace
              )
              pod_names = [pi.metadata.name for pi in pods.items]
              if not pod_names:
                  wandb.termwarn(f"Found no pods for kubernetes job: {self.name}")
                  return None
              logs = await self.core_api.read_namespaced_pod_log(
                  name=pod_names[0], namespace=self.namespace
              )
              if logs:
                  return str(logs)
              wandb.termwarn(f"No logs for kubernetes pod(s): {pod_names}")
              return None
          except Exception as e:
              wandb.termerror(f"{LOG_PREFIX}Failed to get pod logs: {e}")
              return None

      async def wait(self) -> bool:
          """Wait for the run to reach a terminal state.

          Returns:
              True if the run finished successfully, False otherwise (failed,
              preempted or stopped).
          """
          while True:
              status = await self.get_status()
              wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status.state}")
              # "stopped" is terminal too; omitting it would make wait() spin forever.
              if status.state in self._TERMINAL_STATES:
                  break
              await asyncio.sleep(5)
          await self._delete_secret()
          # todo: not sure if this (copied from aws runner) is the right approach?
          return status.state == "finished"

      async def get_status(self) -> Status:
          """Return the monitor's status for this job.

          Once the job is in a terminal state the API key secret is deleted
          (best effort; a failed deletion is retried on the next call).
          """
          status = LaunchKubernetesMonitor.get_status(self.name)
          # Compare the state string: ``status`` is a Status object, so testing
          # the object itself against a list of strings never matches.
          if status.state in self._TERMINAL_STATES:
              try:
                  await self._delete_secret()
              except ApiException as e:
                  # Cleanup must not break status polling; self.secret stays set
                  # so a later call (or wait()) retries the deletion.
                  wandb.termwarn(
                      f"{LOG_PREFIX}Failed to delete secret for job {self.name}: {e}"
                  )
          return status

      async def cancel(self) -> None:
          """Cancel the run by deleting its Job and the API key secret.

          Raises:
              LaunchError: If the Kubernetes API rejects the deletion.
          """
          try:
              await self.batch_api.delete_namespaced_job(
                  namespace=self.namespace,
                  name=self.name,
              )
              await self._delete_secret()
          except ApiException as e:
              raise LaunchError(
                  f"Failed to delete Kubernetes Job {self.name} in namespace {self.namespace}: {str(e)}"
              ) from e

      async def _delete_secret(self) -> None:
          """Delete the API key secret unless it is shared via a helm release.

          A secret that is already gone (HTTP 404) counts as deleted.
          """
          # Cleanup secret if not running in a helm-managed context
          if os.environ.get("WANDB_RELEASE_NAME") or not self.secret:
              return
          try:
              await self.core_api.delete_namespaced_secret(
                  name=self.secret.metadata.name,
                  namespace=self.secret.metadata.namespace,
              )
          except ApiException as e:
              if e.status != 404:
                  raise
          self.secret = None

      async def _delete_auxiliary_resources_by_label(self) -> None:
          """Best-effort deletion of auxiliary resources carrying our label.

          Failures are reported as warnings and never raised.
          """
          if self.auxiliary_resource_label_key is None:
              return
          label_selector = (
              f"{WANDB_K8S_LABEL_AUXILIARY_RESOURCE}={self.auxiliary_resource_label_key}"
          )
          resource_cleanups = (
              (self.core_api, "service"),
              (self.batch_api, "job"),
              (self.core_api, "pod"),
              (self.core_api, "secret"),
              (self.apps_api, "deployment"),
              (self.network_api, "network_policy"),
          )
          try:
              for api_client, resource_type in resource_cleanups:
                  try:
                      list_method = getattr(
                          api_client, f"list_namespaced_{resource_type}"
                      )
                      delete_method = getattr(
                          api_client, f"delete_namespaced_{resource_type}"
                      )
                      # List resources with our label, then delete each one.
                      resources = await list_method(
                          namespace=self.namespace, label_selector=label_selector
                      )
                      for resource in resources.items:
                          await delete_method(
                              name=resource.metadata.name, namespace=self.namespace
                          )
                  except (AttributeError, ApiException) as e:
                      wandb.termwarn(f"Could not clean up {resource_type}: {e}")
          except Exception as e:
              wandb.termwarn(f"Failed to clean up some auxiliary resources: {e}")
  class CrdSubmittedRun(AbstractRun):
      """Run submitted to a CRD backend, e.g. Volcano."""

      # States after which the custom object makes no further progress.
      _TERMINAL_STATES = ("stopped", "failed", "finished", "preempted")

      def __init__(
          self,
          group: str,
          version: str,
          plural: str,
          name: str,
          namespace: str,
          core_api: CoreV1Api,
          custom_api: CustomObjectsApi,
      ) -> None:
          """Create a run object for tracking the progress of a CRD.

          No API calls are made here; the object only stores its arguments.

          Args:
              group: The API group of the CRD.
              version: The API version of the CRD.
              plural: The plural name of the CRD.
              name: The name of the CRD instance.
              namespace: The namespace of the CRD instance.
              core_api: The Kubernetes core API client.
              custom_api: The Kubernetes custom object API client.
          """
          self.group = group
          self.version = version
          self.plural = plural
          self.name = name
          self.namespace = namespace
          self.core_api = core_api
          self.custom_api = custom_api
          self._fail_count = 0

      @property
      def id(self) -> str:
          """Get the name of the custom object."""
          return self.name

      async def get_logs(self) -> str | None:
          """Get logs for all pods labelled with this object's run id."""
          # TODO: test more carefully once we release multi-node support
          logs: dict[str, str | None] = {}
          try:
              pods = await self.core_api.list_namespaced_pod(
                  label_selector=f"wandb/run-id={self.name}", namespace=self.namespace
              )
              pod_names = [pi.metadata.name for pi in pods.items]
              for pod_name in pod_names:
                  logs[pod_name] = await self.core_api.read_namespaced_pod_log(
                      name=pod_name, namespace=self.namespace
                  )
          except ApiException as e:
              wandb.termwarn(f"Failed to get logs for {self.name}: {str(e)}")
              return None
          if not logs:
              return None
          logs_as_array = [f"Pod {pod_name}:\n{log}" for pod_name, log in logs.items()]
          return "\n".join(logs_as_array)

      async def get_status(self) -> Status:
          """Get status of custom object as recorded by the Kubernetes monitor."""
          return LaunchKubernetesMonitor.get_status(self.name)

      async def cancel(self) -> None:
          """Cancel the custom object by deleting it.

          Raises:
              LaunchError: If the Kubernetes API rejects the deletion.
          """
          try:
              await self.custom_api.delete_namespaced_custom_object(
                  group=self.group,
                  version=self.version,
                  namespace=self.namespace,
                  plural=self.plural,
                  name=self.name,
              )
          except ApiException as e:
              raise LaunchError(
                  f"Failed to delete CRD {self.name} in namespace {self.namespace}: {str(e)}"
              ) from e

      async def wait(self) -> bool:
          """Wait for this custom object to reach a terminal state.

          Returns:
              True if the object finished successfully, False if it failed,
              was preempted or was stopped.
          """
          while True:
              status = await self.get_status()
              wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}")
              # "stopped" is terminal too; without it wait() never returns.
              if status.state in self._TERMINAL_STATES:
                  return status.state == "finished"
              await asyncio.sleep(5)
  300. class KubernetesRunner(AbstractRunner):
  301. """Launches runs onto kubernetes."""
  def __init__(
      self,
      api: Api,
      backend_config: dict[str, Any],
      environment: AbstractEnvironment,
      registry: AbstractRegistry,
  ) -> None:
      """Build a runner that launches runs onto Kubernetes.

      Args:
          api: The W&B API client object, handed to the base runner.
          backend_config: The backend configuration, handed to the base runner.
          environment: The environment runs are launched into.
          registry: The container registry; used when creating image pull
              secrets for launched jobs.
      """
      super().__init__(api, backend_config)
      self.registry = registry
      self.environment = environment
  def get_namespace(
      self, resource_args: dict[str, Any], context: dict[str, Any]
  ) -> str:
      """Resolve the namespace a run should be launched into.

      Precedence: ``metadata.namespace`` in the resource args, then a top-level
      ``namespace`` key (kept for malformed specs), then the runner's
      configured namespace, then the kube context's namespace, else "default".

      Args:
          resource_args: The resource args to launch.
          context: The active k8s config context (may be empty/None).

      Returns:
          The namespace to launch into.
      """
      # Resolved up front so a malformed context fails the same way regardless
      # of which source ends up winning.
      fallback = "default"
      if context:
          fallback = context["context"].get("namespace", "default")

      from_metadata = resource_args.get("metadata", {}).get("namespace")
      if from_metadata:
          return from_metadata  # type: ignore[no-any-return]
      # continue support for malformed namespace
      from_top_level = resource_args.get("namespace")
      if from_top_level:
          return from_top_level  # type: ignore[no-any-return]
      from_runner = self.backend_config.get("runner", {}).get("namespace")
      if from_runner:
          return from_runner  # type: ignore[no-any-return]
      return fallback  # type: ignore[no-any-return]
  async def _inject_defaults(
      self,
      resource_args: dict[str, Any],
      launch_project: LaunchProject,
      image_uri: str,
      namespace: str,
      core_api: CoreV1Api,
  ) -> tuple[dict[str, Any], V1Secret | None]:
      """Apply our default values, return job dict and api key secret.

      Args:
          resource_args: The user-supplied resource args. The returned job is
              layered on top of them and shares (and mutates) their nested
              dicts such as ``metadata`` and ``spec.template``.
          launch_project: The launch project.
          image_uri: Image to run when the project was launched from a docker
              image, or when no container in the spec sets an image.
          namespace: The namespace the job will be created in.
          core_api: The core api, used to create the image pull and API key
              secrets.

      Returns:
          The job manifest and the API key secret (None when the API key is
          not passed through a secret).

      Raises:
          LaunchError: If our image must be injected but the project has no
              entry point.
      """
      job: dict[str, Any] = {"apiVersion": "batch/v1", "kind": "Job"}
      job.update(resource_args)
      job_metadata: dict[str, Any] = job.get("metadata", {})
      job_spec: dict[str, Any] = {"backoffLimit": 0, "ttlSecondsAfterFinished": 60}
      job_spec.update(job.get("spec", {}))
      pod_template: dict[str, Any] = job_spec.get("template", {})
      pod_spec: dict[str, Any] = {"restartPolicy": "Never"}
      pod_spec.update(pod_template.get("spec", {}))
      containers: list[dict[str, Any]] = pod_spec.get("containers", [{}])

      # Labels let the Kubernetes monitor (and the owning agent) find the job.
      labels = job_metadata.setdefault("labels", {})
      labels[WANDB_K8S_RUN_ID] = launch_project.run_id
      labels[WANDB_K8S_LABEL_MONITOR] = "true"
      if LaunchAgent.initialized():
          labels[WANDB_K8S_LABEL_AGENT] = LaunchAgent.name()

      # name precedence: name in spec > generated name
      if not job_metadata.get("name"):
          job_metadata["generateName"] = make_name_dns_safe(
              f"launch-{launch_project.target_entity}-{launch_project.target_project}-"
          )
      job_metadata["namespace"] = namespace

      # Default container names and a restrictive security context.
      for i, cont in enumerate(containers):
          cont.setdefault("name", "launch" + str(i))
          cont.setdefault(
              "securityContext",
              {
                  "allowPrivilegeEscalation": False,
                  "capabilities": {"drop": ["ALL"]},
                  "seccompProfile": {"type": "RuntimeDefault"},
              },
          )

      entry_point = (
          launch_project.override_entrypoint or launch_project.get_job_entry_point()
      )
      if launch_project.docker_image:
          # dont specify run id if user provided image, could have multiple runs
          containers[0]["image"] = image_uri
          # TODO: handle secret pulling image from registry
      elif not any("image" in cont for cont in containers):
          if entry_point is None:
              raise LaunchError(
                  "Cannot launch Kubernetes job: no container image is set and "
                  "the project has no entry point."
              )
          # in the non instance case we need to make an imagePullSecret
          # so the new job can pull the image
          containers[0]["image"] = image_uri
          secret = await maybe_create_imagepull_secret(
              core_api, self.registry, launch_project.run_id, namespace
          )
          if secret is not None:
              pod_spec["imagePullSecrets"] = [
                  {"name": f"regcred-{launch_project.run_id}"}
              ]

      inject_entrypoint_and_args(
          containers,
          entry_point,
          launch_project.override_args,
          launch_project.override_entrypoint is not None,
      )

      env_vars = launch_project.get_env_vars_dict(
          self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
      )

      # Under an agent (or in synchronous mode) the API key is referenced from a
      # secret instead of being inlined. All containers share that one secret,
      # so it is ensured once here rather than once per container.
      api_key_secret = None
      api_key_secret_name: str | None = None
      if (
          containers
          and env_vars.get("WANDB_API_KEY")
          and (LaunchAgent.initialized() or self.backend_config[PROJECT_SYNCHRONOUS])
      ):
          # Override API key with secret. TODO: Do the same for other runners
          release_name = os.environ.get("WANDB_RELEASE_NAME")
          api_key_secret_name = (
              f"wandb-api-key-{release_name or launch_project.run_id}"
          )

          def handle_exception(e: Exception) -> None:
              wandb.termwarn(
                  f"Exception when ensuring Kubernetes API key secret: {e}. Retrying..."
              )

          api_key_secret = await retry_async(
              backoff=ExponentialBackoff(
                  initial_sleep=datetime.timedelta(seconds=1),
                  max_sleep=datetime.timedelta(minutes=1),
                  max_retries=API_KEY_SECRET_MAX_RETRIES,
              ),
              fn=ensure_api_key_secret,
              on_exc=handle_exception,
              core_api=core_api,
              secret_name=api_key_secret_name,
              namespace=namespace,
              api_key=env_vars["WANDB_API_KEY"],
          )

      for cont in containers:
          # Add our env vars to user supplied env vars
          env = cont.get("env") or []
          for key, value in env_vars.items():
              if key == "WANDB_API_KEY" and api_key_secret_name is not None:
                  env.append(
                      {
                          "name": key,
                          "valueFrom": {
                              "secretKeyRef": {
                                  "name": api_key_secret_name,
                                  "key": "password",
                              }
                          },
                      }
                  )
              else:
                  env.append({"name": key, "value": value})
          cont["env"] = env

      pod_spec["containers"] = containers
      pod_template["spec"] = pod_spec
      job_spec["template"] = pod_template
      job["spec"] = job_spec
      job["metadata"] = job_metadata
      add_label_to_pods(job, WANDB_K8S_LABEL_MONITOR, "true")

      if launch_project.job_base_image:
          if SOURCE_CODE_PVC_NAME and SOURCE_CODE_PVC_MOUNT_PATH:
              apply_code_mount_configuration(job, launch_project)
          else:
              apply_code_mount_configuration_emptydir(job, launch_project, self._api)

      # Add wandb.ai/agent: current agent label on all pods
      if LaunchAgent.initialized():
          add_label_to_pods(job, WANDB_K8S_LABEL_AGENT, LaunchAgent.name())

      return job, api_key_secret
  async def _wait_for_resource_ready(
      self,
      api_client: kubernetes_asyncio.client.ApiClient,
      config: dict[str, Any],
      namespace: str,
      timeout_seconds: int = 300,
  ) -> None:
      """Block until the resource described by ``config`` reports ready.

      Deployments, Services and Pods get a dedicated readiness check; any other
      kind only gets a fixed five second pause. A config without a kind or a
      name is reported and skipped.

      Args:
          api_client: The Kubernetes API client.
          config: The resource manifest that was created.
          namespace: The namespace the resource was created in.
          timeout_seconds: Maximum time to wait for readiness.
      """
      kind = config.get("kind")
      name = config.get("metadata", {}).get("name")
      if not (kind and name):
          wandb.termerror(
              f"{LOG_PREFIX}Cannot wait for resource without kind or name"
          )
          return

      wandb.termlog(f"{LOG_PREFIX}Waiting for {kind} '{name}' to be ready...")
      started_at = time.time()

      waiters = (
          ("Deployment", self._wait_for_deployment_ready),
          ("Service", self._wait_for_service_ready),
          ("Pod", self._wait_for_pod_ready),
      )
      waiter = next((fn for known_kind, fn in waiters if kind == known_kind), None)
      if waiter is None:
          wandb.termlog(
              f"{LOG_PREFIX}No specific readiness check for {kind}, waiting 5 seconds..."
          )
          await asyncio.sleep(5)
      else:
          await waiter(api_client, name, namespace, timeout_seconds)

      waited = time.time() - started_at
      wandb.termlog(f"{LOG_PREFIX}{kind} '{name}' is ready after {waited:.1f}s")
  async def _wait_for_deployment_ready(
      self,
      api_client: kubernetes_asyncio.client.ApiClient,
      name: str,
      namespace: str,
      timeout_seconds: int,
  ) -> None:
      """Wait until a Deployment has as many ready replicas as it asks for.

      The desired count comes from ``spec.replicas`` (falling back to 1 when
      unset), so a Deployment scaled to 0 replicas is ready immediately instead
      of timing out.

      Raises:
          LaunchError: If the Deployment is not ready within ``timeout_seconds``.
      """
      apps_api = kubernetes_asyncio.client.AppsV1Api(api_client)

      async def check_deployment_ready() -> bool:
          deployment = await apps_api.read_namespaced_deployment(
              name=name, namespace=namespace
          )
          spec_replicas = deployment.spec.replicas if deployment.spec else None
          desired = 1 if spec_replicas is None else spec_replicas
          status = deployment.status
          # ready_replicas is None (not 0) until at least one pod is ready.
          ready = (status.ready_replicas if status else None) or 0
          return ready >= desired

      await self._wait_with_timeout(check_deployment_ready, timeout_seconds, name)
  async def _wait_for_service_ready(
      self,
      api_client: kubernetes_asyncio.client.ApiClient,
      name: str,
      namespace: str,
      timeout_seconds: int,
  ) -> None:
      """Wait until a Service's Endpoints list at least one ready address.

      Readiness is read from the Endpoints object that shares the Service's
      name; only ``addresses`` (ready pods) count, not ``notReadyAddresses``.
      """
      v1 = kubernetes_asyncio.client.CoreV1Api(api_client)

      async def has_ready_address() -> bool:
          endpoints = await v1.read_namespaced_endpoints(
              name=name, namespace=namespace
          )
          return any(subset.addresses for subset in endpoints.subsets or [])

      await self._wait_with_timeout(has_ready_address, timeout_seconds, name)
  async def _wait_for_pod_ready(
      self,
      api_client: kubernetes_asyncio.client.ApiClient,
      name: str,
      namespace: str,
      timeout_seconds: int,
  ) -> None:
      """Wait for a Pod to be Running with all containers ready.

      Raises:
          LaunchError: If the Pod reaches a terminal phase (Failed/Succeeded),
              in which it can never become ready, or if it is not ready within
              ``timeout_seconds``.
      """
      core_api = kubernetes_asyncio.client.CoreV1Api(api_client)

      async def check_pod_ready() -> bool:
          pod = await core_api.read_namespaced_pod(name=name, namespace=namespace)
          pod_status = pod.status
          phase = pod_status.phase if pod_status else None
          if phase in ("Failed", "Succeeded"):
              # Terminal phase: fail now instead of polling until the timeout.
              raise LaunchError(
                  f"Pod '{name}' reached phase {phase} before becoming ready"
              )
          if phase != "Running":
              return False
          container_statuses = pod_status.container_statuses
          if not container_statuses:
              return True
          return all(cs.ready for cs in container_statuses)

      await self._wait_with_timeout(check_pod_ready, timeout_seconds, name)
  601. async def _wait_with_timeout(
  602. self, check_func, timeout_seconds: int, name: str
  603. ) -> None:
  604. """Generic timeout wrapper for readiness checks."""
  605. start_time = time.time()
  606. while time.time() - start_time < timeout_seconds:
  607. try:
  608. if await check_func():
  609. return
  610. except kubernetes_asyncio.client.ApiException as e:
  611. if e.status == 404:
  612. pass
  613. else:
  614. wandb.termerror(
  615. f"{LOG_PREFIX}Error waiting for resource '{name}': {e}"
  616. )
  617. raise
  618. except Exception as e:
  619. wandb.termerror(f"{LOG_PREFIX}Error waiting for resource '{name}': {e}")
  620. raise
  621. await asyncio.sleep(2)
  622. raise LaunchError(
  623. f"Resource '{name}' not ready within {timeout_seconds} seconds"
  624. )
  async def _prepare_resource(
      self,
      api_client: kubernetes_asyncio.client.ApiClient,
      config: dict[str, Any],
      namespace: str,
      run_id: str,
      launch_project: LaunchProject,
      api_key_secret: V1Secret | None = None,
      wait_for_ready: bool = True,
      wait_timeout: int = 300,
      auxiliary_resource_label_value: str | None = None,
  ) -> None:
      """Label, configure and create an auxiliary Kubernetes resource.

      ``config`` is modified in place: run/creator labels are added, the
      project's WANDB_CONFIG is injected into its containers, and (when given)
      the API key secret and auxiliary-resource label are wired in.

      Args:
          api_client: The Kubernetes API client.
          config: The resource configuration to prepare and create.
          namespace: The namespace to create the resource in.
          run_id: The run ID to label the resource with.
          launch_project: The launch project to get environment variables from.
          api_key_secret: The API key secret to reference from containers.
          wait_for_ready: Whether to wait for the resource to be ready after creation.
          wait_timeout: Maximum time in seconds to wait for resource readiness.
          auxiliary_resource_label_value: If set, value of the auxiliary-resource
              label put on the resource and its pods.

      Raises:
          LaunchError: If creating the resource or waiting for it fails.
      """
      metadata = config.setdefault("metadata", {})
      labels = metadata.setdefault("labels", {})
      labels[WANDB_K8S_RUN_ID] = run_id
      labels["wandb.ai/created-by"] = "launch-agent"
      if auxiliary_resource_label_value:
          labels[WANDB_K8S_LABEL_AUXILIARY_RESOURCE] = auxiliary_resource_label_value

      # Only WANDB_CONFIG is forwarded to auxiliary resources.
      env_vars = launch_project.get_env_vars_dict(
          self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
      )
      wandb_config_env = {
          "WANDB_CONFIG": env_vars.get("WANDB_CONFIG", "{}"),
      }
      add_wandb_env(config, wandb_config_env)

      if auxiliary_resource_label_value:
          add_label_to_pods(
              config,
              WANDB_K8S_LABEL_AUXILIARY_RESOURCE,
              auxiliary_resource_label_value,
          )

      if api_key_secret:
          for cont in yield_containers(config):
              cont.setdefault("env", []).append(
                  {
                      "name": "WANDB_API_KEY",
                      "valueFrom": {
                          "secretKeyRef": {
                              "name": api_key_secret.metadata.name,
                              "key": "password",
                          }
                      },
                  }
              )

      try:
          sanitize_identifiers_for_k8s(config)
          await kubernetes_asyncio.utils.create_from_dict(
              api_client, config, namespace=namespace
          )
          if wait_for_ready:
              await self._wait_for_resource_ready(
                  api_client, config, namespace, wait_timeout
              )
      except Exception as e:
          wandb.termerror(f"{LOG_PREFIX}Failed to create Kubernetes resource: {e}")
          # Chain the cause so the original traceback is not lost.
          raise LaunchError(f"Failed to create Kubernetes resource: {e}") from e
  696. async def run(
  697. self, launch_project: LaunchProject, image_uri: str
  698. ) -> AbstractRun | None:
  699. """Execute a launch project on Kubernetes.
  700. Args:
  701. launch_project: The launch project to execute.
  702. builder: The builder to use to build the image.
  703. Returns:
  704. The run object if the run was successful, otherwise None.
  705. """
  706. await LaunchKubernetesMonitor.ensure_initialized()
  707. resource_args = launch_project.fill_macros(image_uri).get("kubernetes", {})
  708. if not resource_args:
  709. wandb.termlog(
  710. f"{LOG_PREFIX}Note: no resource args specified. Add a "
  711. "Kubernetes yaml spec or other options in a json file "
  712. "with --resource-args <json>."
  713. )
  714. _logger.info(f"Running Kubernetes job with resource args: {resource_args}")
  715. context, api_client = await get_kube_context_and_api_client(
  716. kubernetes_asyncio, resource_args
  717. )
  718. # If using pvc for code mount, move code there.
  719. use_emptydir_code_mount = False
  720. if launch_project.job_base_image is not None:
  721. if SOURCE_CODE_PVC_NAME and SOURCE_CODE_PVC_MOUNT_PATH:
  722. code_subdir = launch_project.get_image_source_string()
  723. launch_project.change_project_dir(
  724. os.path.join(SOURCE_CODE_PVC_MOUNT_PATH, code_subdir)
  725. )
  726. else:
  727. use_emptydir_code_mount = True
  728. # If the user specified an alternate api, we need will execute this
  729. # run by creating a custom object.
  730. api_version = resource_args.get("apiVersion", "batch/v1")
  731. if api_version not in ["batch/v1", "batch/v1beta1"]:
  732. env_vars = launch_project.get_env_vars_dict(
  733. self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
  734. )
  735. # Crawl the resource args and add our env vars to the containers.
  736. add_wandb_env(resource_args, env_vars)
  737. # Add our labels to the resource args. This is necessary for the
  738. # agent to find the custom object later on.
  739. resource_args["metadata"] = resource_args.get("metadata", {})
  740. resource_args["metadata"]["labels"] = resource_args["metadata"].get(
  741. "labels", {}
  742. )
  743. resource_args["metadata"]["labels"][WANDB_K8S_LABEL_MONITOR] = "true"
  744. # Crawl the resource arsg and add our labels to the pods. This is
  745. # necessary for the agent to find the pods later on.
  746. add_label_to_pods(
  747. resource_args,
  748. WANDB_K8S_LABEL_MONITOR,
  749. "true",
  750. )
  751. # Add wandb.ai/agent: current agent label on all pods
  752. if LaunchAgent.initialized():
  753. add_label_to_pods(
  754. resource_args,
  755. WANDB_K8S_LABEL_AGENT,
  756. LaunchAgent.name(),
  757. )
  758. resource_args["metadata"]["labels"][WANDB_K8S_LABEL_AGENT] = (
  759. LaunchAgent.name()
  760. )
  761. if launch_project.job_base_image:
  762. if use_emptydir_code_mount:
  763. apply_code_mount_configuration_emptydir(
  764. resource_args, launch_project, self._api
  765. )
  766. else:
  767. apply_code_mount_configuration(resource_args, launch_project)
  768. overrides = {}
  769. if launch_project.override_args:
  770. overrides["args"] = launch_project.override_args
  771. if launch_project.override_entrypoint:
  772. overrides["command"] = launch_project.override_entrypoint.command
  773. add_entrypoint_args_overrides(
  774. resource_args,
  775. overrides,
  776. )
  777. api = client.CustomObjectsApi(api_client)
  778. # Infer the attributes of a custom object from the apiVersion and/or
  779. # a kind: attribute in the resource args.
  780. namespace = self.get_namespace(resource_args, context)
  781. group, version, *_ = api_version.split("/")
  782. group = resource_args.get("group", group)
  783. version = resource_args.get("version", version)
  784. kind = resource_args.get("kind", version)
  785. plural = f"{kind.lower()}s"
  786. custom_resource = CustomResource(
  787. group=group,
  788. version=version,
  789. plural=plural,
  790. )
  791. LaunchKubernetesMonitor.monitor_namespace(
  792. namespace, custom_resource=custom_resource
  793. )
  794. try:
  795. response = await api.create_namespaced_custom_object(
  796. group=group,
  797. version=version,
  798. namespace=namespace,
  799. plural=plural,
  800. body=resource_args,
  801. )
  802. except ApiException as e:
  803. body = json.loads(e.body)
  804. body_yaml = yaml.dump(body)
  805. raise LaunchError(
  806. f"Error creating CRD of kind {kind}: {e.status} {e.reason}\n{body_yaml}"
  807. ) from e
  808. name = response.get("metadata", {}).get("name")
  809. _logger.info(f"Created {kind} {response['metadata']['name']}")
  810. submitted_run = CrdSubmittedRun(
  811. name=name,
  812. group=group,
  813. version=version,
  814. namespace=namespace,
  815. plural=plural,
  816. core_api=client.CoreV1Api(api_client),
  817. custom_api=api,
  818. )
  819. if self.backend_config[PROJECT_SYNCHRONOUS]:
  820. await submitted_run.wait()
  821. return submitted_run
  822. batch_api = kubernetes_asyncio.client.BatchV1Api(api_client)
  823. core_api = kubernetes_asyncio.client.CoreV1Api(api_client)
  824. apps_api = kubernetes_asyncio.client.AppsV1Api(api_client)
  825. network_api = kubernetes_asyncio.client.NetworkingV1Api(api_client)
  826. namespace = self.get_namespace(resource_args, context)
  827. job, secret = await self._inject_defaults(
  828. resource_args, launch_project, image_uri, namespace, core_api
  829. )
  830. update_dict = {
  831. "project_name": launch_project.target_project,
  832. "entity_name": launch_project.target_entity,
  833. "run_id": launch_project.run_id,
  834. "run_name": launch_project.name,
  835. "image_uri": image_uri,
  836. "author": launch_project.author,
  837. }
  838. update_dict.update(os.environ)
  839. additional_services: list[dict[str, Any]] = recursive_macro_sub(
  840. launch_project.launch_spec.get("additional_services", []), update_dict
  841. )
  842. auxiliary_resource_label_value = make_k8s_label_safe(
  843. f"aux-{launch_project.target_entity}-{launch_project.target_project}-{launch_project.run_id}"
  844. )
  845. if additional_services:
  846. wandb.termlog(
  847. f"{LOG_PREFIX}Creating additional services: {additional_services}"
  848. )
  849. wait_for_ready = resource_args.get("wait_for_ready", True)
  850. wait_timeout = resource_args.get("wait_timeout", 300)
  851. await asyncio.gather(
  852. *[
  853. self._prepare_resource(
  854. api_client,
  855. resource.get("config", {}),
  856. namespace,
  857. launch_project.run_id,
  858. launch_project,
  859. secret,
  860. wait_for_ready,
  861. wait_timeout,
  862. auxiliary_resource_label_value,
  863. )
  864. for resource in additional_services
  865. if resource.get("config", {})
  866. ]
  867. )
  868. msg = "Creating Kubernetes job"
  869. if "name" in resource_args:
  870. msg += f": {resource_args['name']}"
  871. _logger.info(msg)
  872. try:
  873. response = await kubernetes_asyncio.utils.create_from_dict(
  874. api_client, job, namespace=namespace
  875. )
  876. except kubernetes_asyncio.utils.FailToCreateError as e:
  877. for exc in e.api_exceptions:
  878. resp = json.loads(exc.body)
  879. msg = resp.get("message")
  880. code = resp.get("code")
  881. raise LaunchError(
  882. f"Failed to create Kubernetes job for run {launch_project.run_id} ({code} {exc.reason}): {msg}"
  883. )
  884. except Exception as e:
  885. raise LaunchError(
  886. f"Unexpected exception when creating Kubernetes job: {str(e)}\n"
  887. )
  888. job_response = response[0]
  889. job_name = job_response.metadata.name
  890. LaunchKubernetesMonitor.monitor_namespace(namespace)
  891. submitted_job = KubernetesSubmittedRun(
  892. batch_api,
  893. core_api,
  894. apps_api,
  895. network_api,
  896. job_name,
  897. namespace,
  898. secret,
  899. auxiliary_resource_label_value,
  900. )
  901. if self.backend_config[PROJECT_SYNCHRONOUS]:
  902. await submitted_job.wait()
  903. return submitted_job
  904. def inject_entrypoint_and_args(
  905. containers: list[dict],
  906. entry_point: EntryPoint | None,
  907. override_args: list[str],
  908. should_override_entrypoint: bool,
  909. ) -> None:
  910. """Inject the entrypoint and args into the containers.
  911. Args:
  912. containers: The containers to inject the entrypoint and args into.
  913. entry_point: The entrypoint to inject.
  914. override_args: The args to inject.
  915. should_override_entrypoint: Whether to override the entrypoint.
  916. Returns:
  917. None
  918. """
  919. for i in range(len(containers)):
  920. if override_args:
  921. containers[i]["args"] = override_args
  922. if entry_point and (
  923. not containers[i].get("command") or should_override_entrypoint
  924. ):
  925. containers[i]["command"] = entry_point.command
  926. async def ensure_api_key_secret(
  927. core_api: CoreV1Api,
  928. secret_name: str,
  929. namespace: str,
  930. api_key: str,
  931. ) -> V1Secret:
  932. """Create a secret containing a user's wandb API key.
  933. Args:
  934. core_api: The Kubernetes CoreV1Api object.
  935. secret_name: The name to use for the secret.
  936. namespace: The namespace to create the secret in.
  937. api_key: The user's wandb API key
  938. Returns:
  939. The created secret
  940. """
  941. secret_data = {"password": base64.b64encode(api_key.encode()).decode()}
  942. labels = {"wandb.ai/created-by": "launch-agent"}
  943. secret = client.V1Secret(
  944. data=secret_data,
  945. metadata=client.V1ObjectMeta(
  946. name=secret_name, namespace=namespace, labels=labels
  947. ),
  948. kind="Secret",
  949. type="kubernetes.io/basic-auth",
  950. )
  951. try:
  952. try:
  953. return await core_api.create_namespaced_secret(namespace, secret)
  954. except ApiException as e:
  955. # 409 = conflict = secret already exists
  956. if e.status == 409:
  957. existing_secret = await core_api.read_namespaced_secret(
  958. name=secret_name, namespace=namespace
  959. )
  960. if existing_secret.data != secret_data:
  961. # If it's a previous secret made by launch agent, clean it up
  962. if (
  963. existing_secret.metadata.labels.get("wandb.ai/created-by")
  964. == "launch-agent"
  965. ):
  966. await core_api.delete_namespaced_secret(
  967. name=secret_name, namespace=namespace
  968. )
  969. return await core_api.create_namespaced_secret(
  970. namespace, secret
  971. )
  972. else:
  973. raise LaunchError(
  974. f"Kubernetes secret already exists in namespace {namespace} with incorrect data: {secret_name}"
  975. )
  976. return existing_secret
  977. raise
  978. except Exception as e:
  979. raise LaunchError(
  980. f"Exception when ensuring Kubernetes API key secret: {str(e)}\n"
  981. )
  982. async def maybe_create_imagepull_secret(
  983. core_api: CoreV1Api,
  984. registry: AbstractRegistry,
  985. run_id: str,
  986. namespace: str,
  987. ) -> V1Secret | None:
  988. """Create a secret for pulling images from a private registry.
  989. Args:
  990. core_api: The Kubernetes CoreV1Api object.
  991. registry: The registry to pull from.
  992. run_id: The run id.
  993. namespace: The namespace to create the secret in.
  994. Returns:
  995. A secret if one was created, otherwise None.
  996. """
  997. secret = None
  998. if isinstance(registry, (LocalRegistry, AzureContainerRegistry)):
  999. # Secret not required
  1000. return None
  1001. uname, token = await registry.get_username_password()
  1002. creds_info = {
  1003. "auths": {
  1004. registry.uri: {
  1005. "auth": base64.b64encode(f"{uname}:{token}".encode()).decode(),
  1006. # need an email but the use is deprecated
  1007. "email": "deprecated@wandblaunch.com",
  1008. }
  1009. }
  1010. }
  1011. secret_data = {
  1012. ".dockerconfigjson": base64.b64encode(json.dumps(creds_info).encode()).decode()
  1013. }
  1014. secret = client.V1Secret(
  1015. data=secret_data,
  1016. metadata=client.V1ObjectMeta(name=f"regcred-{run_id}", namespace=namespace),
  1017. kind="Secret",
  1018. type="kubernetes.io/dockerconfigjson",
  1019. )
  1020. try:
  1021. try:
  1022. return await core_api.create_namespaced_secret(namespace, secret)
  1023. except ApiException as e:
  1024. # 409 = conflict = secret already exists
  1025. if e.status == 409:
  1026. return await core_api.read_namespaced_secret(
  1027. name=f"regcred-{run_id}", namespace=namespace
  1028. )
  1029. raise
  1030. except Exception as e:
  1031. raise LaunchError(f"Exception when creating Kubernetes secret: {str(e)}\n")
  1032. def add_wandb_env(root: dict | list, env_vars: dict[str, str]) -> None:
  1033. """Injects wandb environment variables into specs.
  1034. Recursively walks the spec and injects the environment variables into
  1035. every container spec. Containers are identified by the "containers" key.
  1036. This function treats the WANDB_RUN_ID and WANDB_GROUP_ID environment variables
  1037. specially. If they are present in the spec, they will be overwritten. If a setting
  1038. for WANDB_RUN_ID is provided in env_vars, then that environment variable will only be
  1039. set in the first container modified by this function.
  1040. Args:
  1041. root: The spec to modify.
  1042. env_vars: The environment variables to inject.
  1043. Returns: None.
  1044. """
  1045. for cont in yield_containers(root):
  1046. env = cont.setdefault("env", [])
  1047. env.extend([{"name": key, "value": value} for key, value in env_vars.items()])
  1048. cont["env"] = env
  1049. # After we have set WANDB_RUN_ID once, we don't want to set it again
  1050. if "WANDB_RUN_ID" in env_vars:
  1051. env_vars.pop("WANDB_RUN_ID")
  1052. def yield_pods(manifest: Any) -> Iterator[dict]:
  1053. """Yield all pod specs in a manifest.
  1054. Recursively traverses the manifest and yields all pod specs. Pod specs are
  1055. identified by the presence of a "spec" key with a "containers" key in the
  1056. value.
  1057. """
  1058. if isinstance(manifest, list):
  1059. for item in manifest:
  1060. yield from yield_pods(item)
  1061. elif isinstance(manifest, dict):
  1062. if "spec" in manifest and "containers" in manifest["spec"]:
  1063. yield manifest
  1064. for value in manifest.values():
  1065. if isinstance(value, (dict, list)):
  1066. yield from yield_pods(value)
  1067. def add_label_to_pods(manifest: dict | list, label_key: str, label_value: str) -> None:
  1068. """Add a label to all pod specs in a manifest.
  1069. Recursively traverses the manifest and adds the label to all pod specs.
  1070. Pod specs are identified by the presence of a "spec" key with a "containers"
  1071. key in the value.
  1072. Args:
  1073. manifest: The manifest to modify.
  1074. label_key: The label key to add.
  1075. label_value: The label value to add.
  1076. Returns: None.
  1077. """
  1078. for pod in yield_pods(manifest):
  1079. metadata = pod.setdefault("metadata", {})
  1080. labels = metadata.setdefault("labels", {})
  1081. labels[label_key] = label_value
  1082. def add_entrypoint_args_overrides(manifest: dict | list, overrides: dict) -> None:
  1083. """Add entrypoint and args overrides to all containers in a manifest.
  1084. Recursively traverses the manifest and adds the entrypoint and args overrides
  1085. to all containers. Containers are identified by the presence of a "spec" key
  1086. with a "containers" key in the value.
  1087. Args:
  1088. manifest: The manifest to modify.
  1089. overrides: Dictionary with args and entrypoint keys.
  1090. Returns: None.
  1091. """
  1092. if isinstance(manifest, list):
  1093. for item in manifest:
  1094. add_entrypoint_args_overrides(item, overrides)
  1095. elif isinstance(manifest, dict):
  1096. if "spec" in manifest and "containers" in manifest["spec"]:
  1097. containers = manifest["spec"]["containers"]
  1098. for container in containers:
  1099. if "command" in overrides:
  1100. container["command"] = overrides["command"]
  1101. if "args" in overrides:
  1102. container["args"] = overrides["args"]
  1103. for value in manifest.values():
  1104. add_entrypoint_args_overrides(value, overrides)
  1105. def _set_container_command_with_dep_install(
  1106. container: dict,
  1107. working_dir: str,
  1108. requirements_path: str,
  1109. ) -> None:
  1110. """Set a container's command to install dependencies then exec the original command.
  1111. Replaces command/args with a shell one-liner that installs dependencies, checking in order:
  1112. 1. requirements.txt (user-provided)
  1113. 2. pyproject.toml (user-provided, installed via pip install .)
  1114. 3. requirements.frozen.txt (job artifact fallback)
  1115. Args:
  1116. container: The container spec to modify in place.
  1117. working_dir: The working directory where user dep files are expected.
  1118. requirements_path: Path to the frozen requirements fallback file.
  1119. """
  1120. original_command = container.get("command", [])
  1121. original_args = container.get("args", [])
  1122. original_cmd_str = " ".join(
  1123. shlex.quote(c) for c in original_command + original_args
  1124. )
  1125. if not original_cmd_str:
  1126. return
  1127. user_requirements = f"{working_dir}/requirements.txt"
  1128. pyproject = f"{working_dir}/pyproject.toml"
  1129. install_prefix = (
  1130. f"if [ -f {shlex.quote(user_requirements)} ]; then"
  1131. f" pip install -r {shlex.quote(user_requirements)};"
  1132. f" elif [ -f {shlex.quote(pyproject)} ]; then"
  1133. f" pip install {shlex.quote(working_dir)};"
  1134. f" elif [ -f {shlex.quote(requirements_path)} ]; then"
  1135. f" pip install -r {shlex.quote(requirements_path)};"
  1136. f" else echo 'No requirements file found'; fi"
  1137. )
  1138. container["command"] = ["/bin/sh", "-c"]
  1139. container["args"] = [f"{install_prefix} && exec {original_cmd_str}"]
  1140. def apply_code_mount_configuration(
  1141. manifest: dict | list, project: LaunchProject
  1142. ) -> None:
  1143. """Apply code mount configuration to all containers in a manifest.
  1144. Recursively traverses the manifest and adds the code mount configuration to
  1145. all containers. Containers are identified by the presence of a "spec" key
  1146. with a "containers" key in the value.
  1147. Args:
  1148. manifest: The manifest to modify.
  1149. project: The launch project.
  1150. Returns: None.
  1151. """
  1152. assert SOURCE_CODE_PVC_NAME is not None
  1153. source_dir = project.get_image_source_string()
  1154. for pod in yield_pods(manifest):
  1155. for container in yield_containers(pod):
  1156. if "volumeMounts" not in container:
  1157. container["volumeMounts"] = []
  1158. container["volumeMounts"].append(
  1159. {
  1160. "name": "wandb-source-code-volume",
  1161. "mountPath": CODE_MOUNT_DIR,
  1162. "subPath": source_dir,
  1163. }
  1164. )
  1165. container["workingDir"] = project.resolved_working_dir
  1166. if project._auto_default_base_image:
  1167. _set_container_command_with_dep_install(
  1168. container,
  1169. project.resolved_working_dir,
  1170. f"{CODE_MOUNT_DIR}/.job/requirements.frozen.txt",
  1171. )
  1172. spec = pod["spec"]
  1173. if "volumes" not in spec:
  1174. spec["volumes"] = []
  1175. spec["volumes"].append(
  1176. {
  1177. "name": "wandb-source-code-volume",
  1178. "persistentVolumeClaim": {
  1179. "claimName": SOURCE_CODE_PVC_NAME,
  1180. },
  1181. }
  1182. )
  1183. def _build_code_fetch_script(
  1184. source_type: str,
  1185. source_info: dict,
  1186. install_deps: bool,
  1187. job_dir: str,
  1188. ) -> str:
  1189. """Build the shell script for the init container to fetch source code.
  1190. Args:
  1191. source_type: Either "artifact" or "repo".
  1192. source_info: Source metadata from the launch project.
  1193. install_deps: Whether to also fetch the job artifact for frozen requirements.
  1194. job_dir: Path where the job artifact should be downloaded.
  1195. """
  1196. job_artifact = source_info.get("job_artifact", "")
  1197. chmod_suffix = f" && chmod -R a+w {CODE_MOUNT_DIR}/* || true && chmod -R a+w {CODE_MOUNT_DIR}/.* || true"
  1198. fetch_job_artifact = ""
  1199. if install_deps and job_artifact:
  1200. py_cmd = f"import wandb; wandb.Api().artifact({repr(job_artifact)}).download({repr(job_dir)})"
  1201. fetch_job_artifact = f" && python -c {shlex.quote(py_cmd)}"
  1202. if source_type == "artifact":
  1203. artifact_string = source_info.get("artifact_string", "")
  1204. py_cmd = f"import wandb; wandb.Api().artifact({repr(artifact_string)}).download({repr(CODE_MOUNT_DIR)})"
  1205. return f"python -c {shlex.quote(py_cmd)}" + fetch_job_artifact + chmod_suffix
  1206. else: # repo
  1207. git_remote = source_info.get("git_remote", "")
  1208. git_commit = source_info.get("git_commit", "")
  1209. return (
  1210. f"git clone {shlex.quote(git_remote)} {CODE_MOUNT_DIR}"
  1211. f" && git config --global --add safe.directory {CODE_MOUNT_DIR}"
  1212. f" && cd {CODE_MOUNT_DIR} && git checkout {shlex.quote(git_commit)}"
  1213. + fetch_job_artifact
  1214. + chmod_suffix
  1215. )
  1216. def _build_source_init_container(
  1217. fetch_script: str,
  1218. base_url: str,
  1219. api_key_env: dict | None,
  1220. ) -> dict:
  1221. """Build the init container spec that fetches source code into the emptyDir volume.
  1222. Args:
  1223. fetch_script: Shell script to run in the init container.
  1224. base_url: W&B base URL passed as an environment variable.
  1225. api_key_env: Optional WANDB_API_KEY env dict extracted from a main container.
  1226. """
  1227. init_env: list[dict] = [{"name": "WANDB_BASE_URL", "value": base_url}]
  1228. if api_key_env:
  1229. init_env.append(api_key_env)
  1230. return {
  1231. "name": "wandb-source-code-init",
  1232. "image": "wandb/launch-agent:latest",
  1233. "volumeMounts": [
  1234. {"name": "wandb-source-code-volume", "mountPath": CODE_MOUNT_DIR}
  1235. ],
  1236. "env": init_env,
  1237. "command": ["/bin/sh", "-c", fetch_script],
  1238. }
  1239. def _configure_containers_for_code_mount(
  1240. pod: dict,
  1241. project: LaunchProject,
  1242. install_deps: bool,
  1243. job_dir: str,
  1244. ) -> dict | None:
  1245. """Mount the code volume on all main containers and return the first WANDB_API_KEY env entry.
  1246. Args:
  1247. pod: The pod spec dict to modify in place.
  1248. project: The launch project (for workingDir and dep-install config).
  1249. install_deps: Whether to wrap container commands with a dep-install step.
  1250. job_dir: Path to frozen requirements inside the mounted volume.
  1251. """
  1252. api_key_env = None
  1253. for container in yield_containers(pod):
  1254. container.setdefault("volumeMounts", []).append(
  1255. {"name": "wandb-source-code-volume", "mountPath": CODE_MOUNT_DIR}
  1256. )
  1257. container["workingDir"] = project.resolved_working_dir
  1258. # Only install deps when using the auto-assigned default base image.
  1259. # User-provided base images are expected to already have deps.
  1260. if install_deps:
  1261. _set_container_command_with_dep_install(
  1262. container,
  1263. project.resolved_working_dir,
  1264. f"{job_dir}/requirements.frozen.txt",
  1265. )
  1266. if api_key_env is None:
  1267. for env in container.get("env", []):
  1268. if env["name"] == "WANDB_API_KEY":
  1269. api_key_env = env
  1270. break
  1271. return api_key_env
  1272. def apply_code_mount_configuration_emptydir(
  1273. manifest: dict | list, project: LaunchProject, api: Api
  1274. ) -> None:
  1275. """Apply emptyDir code mount configuration when no PVC is available.
  1276. Uses an init container to fetch code into an emptyDir volume, which is
  1277. then mounted into all main containers.
  1278. Args:
  1279. manifest: The manifest to modify.
  1280. project: The launch project.
  1281. api: The internal API instance (for base_url).
  1282. """
  1283. base_url = api.settings("base_url")
  1284. source_type = project.job_source_type
  1285. source_info = project.job_source_info
  1286. install_deps = project._auto_default_base_image
  1287. # Validate before mutating the manifest.
  1288. if source_type not in ("artifact", "repo"):
  1289. raise LaunchError(
  1290. f"Cannot use emptyDir code mount for unknown source type: {source_type!r}"
  1291. )
  1292. job_dir = f"{CODE_MOUNT_DIR}/.job"
  1293. for pod in yield_pods(manifest):
  1294. spec = pod["spec"]
  1295. spec.setdefault("volumes", []).append(
  1296. {"name": "wandb-source-code-volume", "emptyDir": {}}
  1297. )
  1298. api_key_env = _configure_containers_for_code_mount(
  1299. pod, project, install_deps, job_dir
  1300. )
  1301. fetch_script = _build_code_fetch_script(
  1302. source_type, source_info, install_deps, job_dir
  1303. )
  1304. init_container = _build_source_init_container(
  1305. fetch_script, base_url, api_key_env
  1306. )
  1307. spec.setdefault("initContainers", []).append(init_container)