| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817 |
- from __future__ import annotations
- import asyncio
- import json
- import logging
- import os
- import platform
- import re
- import subprocess
- import sys
- from collections import defaultdict
- from collections.abc import Iterator
- from typing import TYPE_CHECKING, Any, cast
- import click
- import wandb
- import wandb.docker as docker
- from wandb import util
- from wandb.apis.internal import Api
- from wandb.sdk.launch.errors import LaunchError
- from wandb.sdk.launch.git_reference import GitReference
- from wandb.sdk.launch.wandb_reference import WandbReference
- from wandb.sdk.wandb_config import Config
- from .builder.templates._wandb_bootstrap import (
- FAILED_PACKAGES_POSTFIX,
- FAILED_PACKAGES_PREFIX,
- )
# Captures (group 1) the text printed between FAILED_PACKAGES_PREFIX and
# FAILED_PACKAGES_POSTFIX by the build bootstrap template.
FAILED_PACKAGES_REGEX = re.compile(
    f"{re.escape(FAILED_PACKAGES_PREFIX)}(.*){re.escape(FAILED_PACKAGES_POSTFIX)}"
)
if TYPE_CHECKING:  # pragma: no cover
    # Only needed for type annotations; not imported at runtime.
    from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
# TODO: this should be restricted to just Git repos and not S3 and stuff like that
# NOTE(review): inside the character class `|` and `^` are literal characters, so
# the first character may not be "/", "|", "^", "~" or "." (not only "/", "~", ".").
_GIT_URI_REGEX = re.compile(
    r"^[^/|^~|^\.].*(git|bitbucket|dev\.azure\.com|\.visualstudio\.com)"
)
# http(s) URL whose host is a dotted-quad of digits, with an optional port.
_VALID_IP_REGEX = r"^https?://[0-9]+(?:\.[0-9]+){3}(:[0-9]+)?"
# Characters accepted in a pip package name.
_VALID_PIP_PACKAGE_REGEX = r"^[a-zA-Z0-9_.-]+$"
# NOTE(review): the "." after "api" is unescaped and matches any character.
_VALID_WANDB_REGEX = r"^https?://(api.)?wandb"
# A URI is treated as a W&B server URI if either pattern above matches.
_WANDB_URI_REGEX = re.compile(r"|".join([_VALID_WANDB_REGEX, _VALID_IP_REGEX]))
# QA, dev and local-development hosts; for testing, not sure if we want to keep
# these. NOTE(review): the "." characters in these patterns are unescaped.
_WANDB_QA_URI_REGEX = re.compile(
    r"^https?://ap\w.qa.wandb"
)
_WANDB_DEV_URI_REGEX = re.compile(
    r"^https?://ap\w.wandb.test"
)
_WANDB_LOCAL_DEV_URI_REGEX = re.compile(
    r"^https?://localhost"
)
# Matches "WANDB_API_KEY=<key>" so the key can be redacted. NOTE(review): only
# one "-"-separated segment after the first word is matched.
API_KEY_REGEX = r"WANDB_API_KEY=\w+(-\w+)?"
# Matches "${name}" macros; group 1 is the macro name.
MACRO_REGEX = re.compile(r"\$\{(\w+)\}")
# Azure Container Registry URI: group 1 is the registry name, plus named
# "repository" and "tag" groups.
AZURE_CONTAINER_REGISTRY_URI_REGEX = re.compile(
    r"^(?:https://)?([\w]+)\.azurecr\.io/(?P<repository>[\w\-]+):?(?P<tag>.*)"
)
# Amazon ECR URI with named "account", "region", "repository" and "tag" groups.
ELASTIC_CONTAINER_REGISTRY_URI_REGEX = re.compile(
    r"^(?:https://)?(?P<account>[\w-]+)\.dkr\.ecr\.(?P<region>[\w-]+)\.amazonaws\.com/(?P<repository>[\.\/\w-]+):?(?P<tag>.*)$"
)
# GCP Artifact Registry URI with named "region", "project", "repository",
# optional "image_name" and optional "tag" (the tag group includes its ":").
GCP_ARTIFACT_REGISTRY_URI_REGEX = re.compile(
    r"^(?:https://)?(?P<region>[\w-]+)-docker\.pkg\.dev/(?P<project>[\w-]+)/(?P<repository>[\w-]+)/?(?P<image_name>[\w-]+)?(?P<tag>:.*)?$",
    re.IGNORECASE,
)
# Object-storage URIs: S3 (bucket, "/key", key), GCS (bucket, path) and Azure
# Blob (account, container, path).
S3_URI_RE = re.compile(r"s3://([^/]+)(/(.*))?")
GCS_URI_RE = re.compile(r"gs://([^/]+)(?:/(.*))?")
AZURE_BLOB_REGEX = re.compile(
    r"^https://([^\.]+)\.blob\.core\.windows\.net/([^/]+)/?(.*)$"
)
# Captures the partition field (group 1) of an "arn:" string with exactly six
# colon-separated fields.
ARN_PARTITION_RE = re.compile(r"^arn:([^:]+):[^:]*:[^:]*:[^:]*:[^:]*$")
PROJECT_SYNCHRONOUS = "SYNCHRONOUS"
# Launch config file path ("~" is not expanded here).
LAUNCH_CONFIG_FILE = "~/.config/wandb/launch-config.yaml"
LAUNCH_DEFAULT_PROJECT = "model-registry"
_logger = logging.getLogger(__name__)
# Magenta "launch:" prefix for terminal messages.
LOG_PREFIX = f"{click.style('launch:', fg='magenta')} "
# Max environment-variable length keyed by runner name (e.g. "SageMakerRunner");
# unknown keys default to 32670.
MAX_ENV_LENGTHS: dict[str, int] = defaultdict(lambda: 32670)
MAX_ENV_LENGTHS["SageMakerRunner"] = 512
# NOTE(review): presumably where code is mounted inside containers — confirm.
CODE_MOUNT_DIR = "/mnt/wandb"
def load_wandb_config() -> Config:
    """Build a wandb ``Config`` from the WANDB_CONFIG environment variable(s).

    The JSON text comes from ``WANDB_CONFIG`` when it is set. Otherwise it is
    reassembled by concatenating ``WANDB_CONFIG_0``, ``WANDB_CONFIG_1``, ... in
    order, stopping at the first missing index. The sharded form exists for
    environments that limit the length of a single environment variable.

    Returns:
        A ``Config`` updated with the parsed JSON values.

    Raises:
        LaunchError: If neither WANDB_CONFIG nor WANDB_CONFIG_0 is set, or if the
            collected text is not valid JSON.
    """
    raw = os.environ.get("WANDB_CONFIG")
    if raw is None:
        shards: list[str] = []
        # Indices must be contiguous from 0; the first gap ends the scan.
        while (shard := os.environ.get(f"WANDB_CONFIG_{len(shards)}")) is not None:
            shards.append(shard)
        if not shards:
            raise LaunchError(
                "No WANDB_CONFIG or WANDB_CONFIG_[0-9]+ environment variables found"
            )
        raw = "".join(shards)

    config = Config()
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError as err:
        raise LaunchError(f"Failed to parse WANDB_CONFIG: {err}") from err
    config.update(parsed)
    return config
def event_loop_thread_exec(func: Any) -> Any:
    """Wrap ``func`` so that calling it returns an awaitable run in a thread.

    The wrapper submits ``func(*args, **kwargs)`` to the running event loop's
    default executor and awaits the result. This lets blocking functions be used
    from async code without blocking the loop. ``functools.wraps`` keeps the
    wrapped function's name and docstring on the wrapper.

    Example usage:
    ```
    def my_func(arg1, arg2):
        return arg1 + arg2

    future = event_loop_thread_exec(my_func)(2, 2)
    assert await future == 4
    ```

    The returned function must be called within an active event loop.
    """
    import functools

    @functools.wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        # get_running_loop() is the recommended call inside a coroutine. Unlike
        # get_event_loop() it never falls back to looking up or creating a loop.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, functools.partial(func, *args, **kwargs)
        )

    return wrapper
def _is_wandb_uri(uri: str) -> bool:
    """Return True if ``uri`` looks like a W&B server URL.

    Production hosts and bare IP addresses count, as do the dev, localhost and
    QA hosts.
    """
    candidates = (
        _WANDB_URI_REGEX,
        _WANDB_DEV_URI_REGEX,
        _WANDB_LOCAL_DEV_URI_REGEX,
        _WANDB_QA_URI_REGEX,
    )
    return any(pattern.match(uri) for pattern in candidates)
def _is_wandb_dev_uri(uri: str) -> bool:
    """Return True if ``uri`` starts with the W&B dev host pattern."""
    return _WANDB_DEV_URI_REGEX.match(uri) is not None
def _is_wandb_local_uri(uri: str) -> bool:
    """Return True if ``uri`` is an http(s) URL on ``localhost``."""
    return _WANDB_LOCAL_DEV_URI_REGEX.match(uri) is not None
def _is_git_uri(uri: str) -> bool:
    """Heuristically decide whether ``uri`` refers to a git repository.

    The URI must not start with a path-like character and must mention a git,
    Bitbucket or Azure DevOps host after its first character.
    """
    return _GIT_URI_REGEX.match(uri) is not None
def sanitize_wandb_api_key(s: str) -> str:
    """Redact W&B API keys embedded in ``s``.

    Each ``WANDB_API_KEY=<key>`` match of API_KEY_REGEX is replaced by the bare
    text ``WANDB_API_KEY``, removing both the ``=`` and the key value.
    """
    redacted = re.sub(API_KEY_REGEX, "WANDB_API_KEY", s)
    return str(redacted)
def get_project_from_job(job: str) -> str | None:
    """Return the middle component of an ``entity/project/name`` job path.

    Returns None unless ``job`` splits into exactly three ``/``-separated parts.
    """
    pieces = job.split("/")
    return pieces[1] if len(pieces) == 3 else None
def set_project_entity_defaults(
    uri: str | None,
    job: str | None,
    api: Api,
    project: str | None,
    entity: str | None,
    launch_config: dict[str, Any] | None,
) -> tuple[str | None, str]:
    """Fill in the target project and entity when they were not provided.

    The project falls back to ``launch_config["project"]``, then to a name
    derived from the source, then to ``""``. The source name is the project of
    a W&B URI, the repository basename of a git URI, or the project component
    of ``job`` (consulted only when ``uri`` is None). The entity falls back to
    ``get_default_entity``. The resolved target is logged to the terminal.

    Returns:
        ``(project, entity)``.
    """
    source_uri = None
    if uri is not None:
        if _is_wandb_uri(uri):
            _, source_uri, _ = parse_wandb_uri(uri)
        elif _is_git_uri(uri):
            source_uri = os.path.splitext(os.path.basename(uri))[0]
    elif job is not None:
        source_uri = get_project_from_job(job)

    if project is None:
        config_project = launch_config.get("project") if launch_config else None
        project = config_project or source_uri or ""
    if entity is None:
        entity = get_default_entity(api, launch_config)

    # Encoding names are case-insensitive: CPython normally reports "utf-8", so
    # an exact comparison with "UTF-8" never matched. stdout may also be None or
    # lack an encoding when it has been replaced.
    encoding = (getattr(sys.stdout, "encoding", None) or "").replace("_", "-").lower()
    prefix = ""
    if platform.system() != "Windows" and encoding in ("utf-8", "utf8"):
        prefix = "🚀 "
    wandb.termlog(
        f"{LOG_PREFIX}{prefix}Launching run into {entity}{'/' + project if project else ''}"
    )
    return project, entity
def get_default_entity(api: Api, launch_config: dict[str, Any] | None):
    """Return ``launch_config["entity"]``, falling back to ``api.default_entity``.

    A missing launch config, or a missing or falsy ``entity`` entry, falls back
    to the API's default entity.
    """
    if not launch_config:
        return api.default_entity
    return launch_config.get("entity") or api.default_entity
def strip_resource_args_and_template_vars(launch_spec: dict[str, Any]) -> None:
    """Remove ``resource_args`` when ``template_variables`` is also set.

    Modifies ``launch_spec`` in place and prints a warning. If either key is
    missing or falsy, the spec is left untouched.
    """
    conflict = bool(launch_spec.get("resource_args")) and bool(
        launch_spec.get("template_variables")
    )
    if not conflict:
        return
    wandb.termwarn(
        "Launch spec contains both resource_args and template_variables, "
        "only one can be set. Using template_variables."
    )
    launch_spec.pop("resource_args")
def construct_launch_spec(
    uri: str | None,
    job: str | None,
    api: Api,
    name: str | None,
    project: str | None,
    entity: str | None,
    docker_image: str | None,
    resource: str | None,
    entry_point: list[str] | None,
    version: str | None,
    resource_args: dict[str, Any] | None,
    launch_config: dict[str, Any] | None,
    run_id: str | None,
    repository: str | None,
    author: str | None,
    sweep_id: str | None = None,
) -> dict[str, Any]:
    """Construct the launch specification from CLI arguments.

    Supplied arguments are layered over ``launch_config``. When
    ``launch_config`` is not None it is used as the spec object itself, so it
    is modified in place and returned. A ``resource`` already present in
    ``launch_config`` takes precedence over the ``resource`` argument.

    Raises:
        LaunchError: If ``overrides["args"]`` is present but is not a list.
    """
    # override base config (if supplied) with supplied args
    launch_spec = launch_config if launch_config is not None else {}
    if uri is not None:
        launch_spec["uri"] = uri
    if job is not None:
        launch_spec["job"] = job
    project, entity = set_project_entity_defaults(
        uri,
        job,
        api,
        project,
        entity,
        launch_config,
    )
    launch_spec["entity"] = entity
    if author:
        launch_spec["author"] = author
    launch_spec["project"] = project
    if name:
        launch_spec["name"] = name
    if "docker" not in launch_spec:
        launch_spec["docker"] = {}
    if docker_image:
        launch_spec["docker"]["docker_image"] = docker_image
    if sweep_id:  # all runs in a sweep have this set
        launch_spec["sweep_id"] = sweep_id
    if "resource" not in launch_spec:
        launch_spec["resource"] = resource if resource else None
    if "git" not in launch_spec:
        launch_spec["git"] = {}
    if version:
        launch_spec["git"]["version"] = version
    if "overrides" not in launch_spec:
        launch_spec["overrides"] = {}
    if not isinstance(launch_spec["overrides"].get("args", []), list):
        raise LaunchError("override args must be a list of strings")
    if resource_args:
        launch_spec["resource_args"] = resource_args
    if entry_point:
        launch_spec["overrides"]["entry_point"] = entry_point
    if run_id is not None:
        launch_spec["run_id"] = run_id
    if repository:
        # Write to launch_spec, not launch_config. When launch_config is None,
        # launch_spec is a fresh dict, so writing to a new launch_config dict
        # would silently drop the repository.
        registry = launch_spec.get("registry")
        if registry:
            registry["url"] = repository
        else:
            launch_spec["registry"] = {"url": repository}
    # dont send both resource args and template variables
    strip_resource_args_and_template_vars(launch_spec)
    return launch_spec
def validate_launch_spec_source(launch_spec: dict[str, Any]) -> None:
    """Ensure the launch spec names exactly one source: a job or a docker image.

    Raises:
        LaunchError: If both or neither of ``job`` and ``docker.docker_image``
            are set. Falsy values count as unset.
    """
    job = launch_spec.get("job")
    # "docker" may be present but null (e.g. from a JSON/YAML config). Treat it
    # like an empty section instead of failing with AttributeError.
    docker_section = launch_spec.get("docker") or {}
    docker_image = docker_section.get("docker_image")
    if bool(job) == bool(docker_image):
        raise LaunchError(
            "Exactly one of job or docker_image must be specified in the launch spec."
        )
def parse_wandb_uri(uri: str) -> tuple[str, str, str]:
    """Split a W&B run URI into ``(entity, project, run_id)``.

    Raises:
        LaunchError: If the URI cannot be parsed or any of the three parts is
            missing.
    """
    ref = WandbReference.parse(uri)
    if ref and ref.entity and ref.project and ref.run_id:
        return (ref.entity, ref.project, ref.run_id)
    raise LaunchError(f"Trouble parsing wandb uri {uri}")
def get_local_python_deps(
    dir: str, filename: str = "requirements.local.txt"
) -> str | None:
    """Write ``pip freeze`` output for the current environment to ``dir/filename``.

    Returns:
        ``filename`` on success. Returns None if ``pip`` could not be started or
        exited with a non-zero status; an error is printed in that case.
    """
    with open(os.path.join(dir, filename), "w") as f:
        try:
            # check_call (unlike call) raises on a non-zero exit status, so a
            # failing `pip freeze` is reported instead of looking like success.
            subprocess.check_call(["pip", "freeze"], env=os.environ, stdout=f)
        except (subprocess.CalledProcessError, OSError) as e:
            # OSError covers `pip` missing from PATH.
            wandb.termerror(f"Command failed: {e}")
            return None
    return filename
def diff_pip_requirements(req_1: list[str], req_2: list[str]) -> dict[str, str]:
    """Return the requirements that differ between two pip requirement lists.

    Both lists are parsed into ``{package_name: version}`` mappings. Names are
    lowercased, and the version is None for unpinned names. The symmetric
    difference of the two mappings is returned:

    - A package present on only one side maps to its bare version.
    - A package with different versions on each side maps to
      ``"v<b> and v<a>"``, where ``a`` sorts before ``b``.

    Supported line forms:

    - ``name``.
    - ``name<op>version`` for ``===``, ``==``, ``~=``, ``!=``, ``>=``, ``<=``,
      ``>`` and ``<``.
    - VCS requirements (``git+``/``hg+``) named via ``#egg=<name>`` or
      ``<name> @ <url>``.

    Ignored lines: blank lines, ``#`` comments (whole-line, or inline after
    whitespace), and option lines starting with ``-`` (for example ``-r``,
    ``--index-url`` or non-VCS ``-e``), which carry no comparable version.

    Raises:
        LaunchError: If a line cannot be parsed or yields an invalid package name.
    """
    # A package name, then optionally an operator and the version spec after it.
    requirement_re = re.compile(
        r"^(?P<name>[a-zA-Z0-9_.-]+)"
        r"(?:\s*(?P<op>===|==|~=|!=|>=|<=|>|<)\s*(?P<version>.*))?$"
    )

    def _parse_req(req: list[str]) -> dict[str, str | None]:
        # TODO: This can be made more exhaustive, but for 99% of cases this is fine
        # see https://pip.pypa.io/en/stable/reference/requirements-file-format/#example
        parsed: dict[str, str | None] = {}
        for raw_line in req:
            # pip treats "#" preceded by whitespace as the start of a comment.
            line = re.sub(r"\s+#.*$", "", raw_line).strip()
            if not line or line.startswith("#"):
                continue
            is_vcs = "git+" in line or "hg+" in line
            if line.startswith("-") and not is_vcs:
                continue
            version: str | None = None
            if is_vcs:
                if "#egg=" in line:
                    # Further URL fragment params (e.g. "&subdirectory=") follow "&".
                    name = line.split("#egg=")[1].split("&")[0]
                elif " @ " in line:
                    name = line.split(" @ ")[0].strip()
                else:
                    raise ValueError(
                        f"Unable to parse pip requirements file line: {line}"
                    )
                version = line.split("@")[-1].split("#")[0]
                # Explicit check instead of `assert`, which `python -O` strips.
                if re.match(_VALID_PIP_PACKAGE_REGEX, name) is None:
                    raise ValueError(f"Invalid pip package name {name}")
            else:
                match = requirement_re.match(line)
                if match is None:
                    raise ValueError(
                        f"Unable to parse pip requirements file line: {line}"
                    )
                name = match.group("name")
                if match.group("op") is not None:
                    version = match.group("version").split("#")[0].strip()
            parsed[name.lower()] = version
        return parsed

    try:
        req_1_dict = _parse_req(req_1)
        req_2_dict = _parse_req(req_2)
    except (ValueError, IndexError, KeyError) as e:
        raise LaunchError(f"Failed to parse pip requirements: {e}") from e

    # Sort the symmetric difference so the combined "v.. and v.." text is
    # deterministic (set iteration order is not).
    diff = sorted(
        set(req_1_dict.items()) ^ set(req_2_dict.items()),
        key=lambda item: (item[0], "" if item[1] is None else item[1]),
    )
    pretty_diff: dict[str, str] = {}
    for name, version in diff:
        # Membership (not `.get() is None`) so an unpinned (None) first entry
        # is merged with the second one instead of being overwritten.
        if name in pretty_diff:
            pretty_diff[name] = f"v{version} and v{pretty_diff[name]}"
        else:
            pretty_diff[name] = version  # type: ignore[assignment]
    return pretty_diff
def validate_wandb_python_deps(
    requirements_file: str | None,
    dir: str,
) -> None:
    """Warn if local python dependencies differ from wandb requirements.txt.

    Compares ``dir/requirements_file`` with a fresh ``pip freeze`` of the local
    environment and logs a warning listing the differing packages. If either
    side is unavailable, it logs that validation was not possible.

    Raises:
        LaunchError: Propagated from ``diff_pip_requirements`` when a
            requirements line cannot be parsed.
    """
    if requirements_file is not None:
        requirements_path = os.path.join(dir, requirements_file)
        with open(requirements_path) as f:
            wandb_python_deps: list[str] = f.read().splitlines()
        local_python_file = get_local_python_deps(dir)
        if local_python_file is not None:
            local_python_deps_path = os.path.join(dir, local_python_file)
            with open(local_python_deps_path) as f:
                local_python_deps: list[str] = f.read().splitlines()
            # Without this check the diff was computed and discarded, so
            # mismatches were never reported.
            diff = diff_pip_requirements(wandb_python_deps, local_python_deps)
            if diff:
                _logger.warning(
                    "Local python dependencies differ from %s: %s",
                    requirements_file,
                    ", ".join(f"{pkg} ({ver})" for pkg, ver in sorted(diff.items())),
                )
            return
    _logger.warning("Unable to validate local python dependencies")
def apply_patch(patch_string: str, dst_dir: str) -> None:
    """Write ``patch_string`` to ``dst_dir/diff.patch`` and apply it to ``dst_dir``.

    The patch is applied silently (``-s``) with ``-p1`` from inside ``dst_dir``.

    Raises:
        wandb.Error: If the ``patch`` executable cannot be run or reports a failure.
    """
    _logger.info("Applying diff.patch")
    patch_path = os.path.join(dst_dir, "diff.patch")
    with open(patch_path, "w") as fp:
        fp.write(patch_string)
    try:
        subprocess.check_call(
            [
                "patch",
                "-s",
                f"--directory={dst_dir}",
                "-p1",
                "-i",
                # Absolute path so the input file resolves no matter how the
                # patch implementation orders --directory and -i.
                os.path.abspath(patch_path),
            ]
        )
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing `patch` executable.
        raise wandb.Error("Failed to apply diff.patch associated with run.") from e
def _fetch_git_repo(dst_dir: str, uri: str, version: str | None) -> str | None:
    """Fetch the git repo at ``uri`` into ``dst_dir`` for ``version``.

    Authentication parameters are assumed to come from the environment, e.g.
    a Git credential helper.

    Returns:
        ``version`` if it was given. Otherwise ``ref.ref`` as read from the
        ``GitReference`` after fetching.
    """
    _logger.info("Fetching git repo")
    # Calling the GitReference class always yields an instance (or raises), so
    # no None check is needed here.
    ref = GitReference(uri, version)
    ref.fetch(dst_dir)
    if version is None:
        version = ref.ref
    return version
def convert_jupyter_notebook_to_script(fname: str, project_dir: str) -> str:
    """Export the notebook ``project_dir/fname`` as a Python script beside it.

    Code-cell lines that start with ``!`` (shell escapes) are removed before
    exporting with nbconvert's ``PythonExporter``.

    Returns:
        The script's name: ``fname`` with its final extension replaced by ``.py``.
    """
    nbconvert = wandb.util.get_module(
        "nbconvert", "nbformat and nbconvert are required to use launch with notebooks"
    )
    nbformat = wandb.util.get_module(
        "nbformat", "nbformat and nbconvert are required to use launch with notebooks"
    )
    _logger.info("Converting notebook to script")
    # splitext changes only the trailing extension. str.replace would also
    # rewrite ".ipynb" appearing elsewhere in the path (e.g. a directory name).
    new_name = os.path.splitext(fname)[0] + ".py"
    # Notebook files are UTF-8 JSON; don't depend on the platform's default
    # encoding (e.g. cp1252 on Windows).
    with open(os.path.join(project_dir, fname), encoding="utf-8") as fh:
        nb = nbformat.reads(fh.read(), nbformat.NO_CONVERT)
    for cell in nb.cells:
        if cell.cell_type == "code":
            cell.source = "\n".join(
                line for line in cell.source.split("\n") if not line.startswith("!")
            )
    exporter = nbconvert.PythonExporter()
    source, _ = exporter.from_notebook_node(nb)
    with open(os.path.join(project_dir, new_name), "w+", encoding="utf-8") as fh:
        fh.write(source)
    return new_name
def to_camel_case(maybe_snake_str: str) -> str:
    """Convert a snake_case string to CamelCase with a capitalized first letter.

    Each underscore-separated word is title-cased, which also lowercases the
    rest of the word. Each empty segment (from leading, trailing or repeated
    underscores) becomes a single ``_``. Strings without an underscore are
    returned unchanged.
    """
    if "_" in maybe_snake_str:
        # str.title() of a non-empty word is non-empty, so `or` only fires
        # for empty segments.
        return "".join(word.title() or "_" for word in maybe_snake_str.split("_"))
    return maybe_snake_str
def validate_build_and_registry_configs(
    build_config: dict[str, Any], registry_config: dict[str, Any]
) -> None:
    """Ensure the builder and registry configs agree on credentials.

    Raises:
        LaunchError: If both configs carry truthy ``credentials`` that differ.
    """
    build_creds = build_config.get("credentials", {})
    registry_creds = registry_config.get("credentials", {})
    if not build_creds or not registry_creds:
        return
    if build_creds != registry_creds:
        raise LaunchError("registry and build config credential mismatch")
async def get_kube_context_and_api_client(
    kubernetes: Any,
    resource_args: dict[str, Any],
) -> tuple[Any, Any]:
    """Load a Kubernetes configuration and create an API client.

    If ``resource_args["configFile"]`` is set or ``~/.kube/config`` exists, the
    kubeconfig is loaded for the context named by ``resource_args["context"]``.
    The active context is used when that key is unset. Otherwise the
    in-cluster config is loaded and the returned context is None.

    Args:
        kubernetes: The Kubernetes client module. NOTE(review): the awaited config
            calls suggest kubernetes_asyncio — confirm.
        resource_args: Resource arguments; ``configFile`` and ``context`` are read.

    Returns:
        ``(context, api_client)``.

    Raises:
        LaunchError: If the requested context is not in the kubeconfig.
    """
    config_file = resource_args.get("configFile")
    context = None
    if config_file is not None or os.path.exists(os.path.expanduser("~/.kube/config")):
        # context only exist in the non-incluster case
        (
            all_contexts,
            active_context,
        ) = kubernetes.config.list_kube_config_contexts(config_file)
        if resource_args.get("context"):
            context_name = resource_args["context"]
            for c in all_contexts:
                if c["name"] == context_name:
                    context = c
                    break
            else:
                # for/else: only reached when no context matched. Without the
                # `else`, this raise also fired after a successful match.
                raise LaunchError(f"Specified context {context_name} was not found.")
        else:
            context = active_context
        # TODO: We should not really be performing this check if the user is not
        # using EKS but I don't see an obvious way to make an eks specific code path
        # right here.
        util.get_module(
            "awscli",
            "awscli is required to load a kubernetes context "
            "from eks. Please run `pip install wandb[launch]` to install it.",
        )
        await kubernetes.config.load_kube_config(config_file, context["name"])
        api_client = await kubernetes.config.new_client_from_config(
            config_file, context=context["name"]
        )
        return context, api_client
    else:
        kubernetes.config.load_incluster_config()
        api_client = kubernetes.client.api_client.ApiClient()
        return context, api_client
def resolve_build_and_registry_config(
    default_launch_config: dict[str, Any] | None,
    build_config: dict[str, Any] | None,
    registry_config: dict[str, Any] | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Choose the effective builder and registry configs and validate them together.

    An explicitly passed config wins. Otherwise the ``builder``/``registry``
    section of ``default_launch_config`` is used, and failing that an empty dict.

    Raises:
        LaunchError: From ``validate_build_and_registry_configs`` on a
            credential mismatch.
    """
    if build_config is not None:
        builder = build_config
    elif default_launch_config is not None:
        builder = default_launch_config.get("builder", {})
    else:
        builder = {}

    if registry_config is not None:
        registry = registry_config
    elif default_launch_config is not None:
        registry = default_launch_config.get("registry", {})
    else:
        registry = {}

    validate_build_and_registry_configs(builder, registry)
    return builder, registry
def check_logged_in(api: Api) -> bool:
    """Confirm the current API key works by loading the viewer.

    Expected time cost is 0.1-0.2 seconds.

    Returns:
        True when the viewer query returns a truthy result.

    Raises:
        LaunchError: If the viewer does not load (likely a broken API key).
    """
    if api.api.viewer():
        return True
    raise LaunchError(
        "Could not connect with current API-key. "
        "Please relogin using `wandb login --relogin`"
        " and try again (see `wandb login --help` for more options)"
    )
def make_name_dns_safe(name: str) -> str:
    """Turn ``name`` into a DNS-style name prefix.

    Underscores become hyphens and the text is lowercased. Everything except
    lowercase letters, digits, ``.`` and ``-`` is removed. Leading ``.``/``-``
    are stripped because DNS names must begin with an alphanumeric character.
    The result is capped at 200 characters.
    """
    resp = name.replace("_", "-").lower()
    # Digits are valid in DNS names; the previous pattern dropped them.
    resp = re.sub(r"[^a-z0-9\.\-]", "", resp)
    resp = resp.lstrip(".-")
    # Actual length limit is 253, but we want to leave room for the generated suffix
    return resp[:200]
def make_k8s_label_safe(value: str) -> str:
    """Convert ``value`` into a Kubernetes DNS-1123 label.

    See:
    https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-label-names

    Steps:
    - underscores become hyphens and the text is lowercased;
    - characters other than ``a-z``, ``0-9`` and ``-`` are removed;
    - runs of hyphens are collapsed to one;
    - the result is cut to 63 characters, then leading/trailing hyphens are stripped.

    Raises:
        LaunchError: If nothing valid remains.
    """
    label = re.sub(r"[^a-z0-9\-]", "", value.lower().replace("_", "-"))
    label = re.sub(r"-+", "-", label)
    # Truncate before stripping so a hyphen exposed at position 63 is removed too.
    label = label[:63].strip("-")
    if label:
        return label
    raise LaunchError(f"Invalid value for Kubernetes label: {value}")
def warn_failed_packages_from_build_logs(
    log: str, image_uri: str, api: Api, job_tracker: JobAndRunStatusTracker | None
) -> None:
    """Report packages that the image build failed to install.

    Searches ``log`` for the failed-packages marker. When the marker is found,
    prints a warning. If ``job_tracker`` is given, it also saves the message via
    ``job_tracker.saver.save_contents`` as ``failed-packages.log`` and attaches
    a warning to the tracker's run queue item.
    """
    found = FAILED_PACKAGES_REGEX.search(log)
    if found is None:
        return
    message = f"Failed to install the following packages: {found.group(1)} for image: {image_uri}. Will attempt to launch image without them."
    wandb.termwarn(message)
    if job_tracker is None:
        return
    saved = job_tracker.saver.save_contents(message, "failed-packages.log", "warning")
    api.update_run_queue_item_warning(
        job_tracker.run_queue_item_id,
        "Some packages were not successfully installed during the build",
        "build",
        saved,
    )
def docker_image_exists(docker_image: str, should_raise: bool = False) -> bool:
    """Return whether ``docker_image`` is available locally.

    Runs ``docker image inspect``. When ``should_raise`` is True, the
    DockerError/ValueError from that command is re-raised instead of
    returning False.
    """
    _logger.info("Checking if base image exists...")
    try:
        docker.run(["docker", "image", "inspect", docker_image])
    except (docker.DockerError, ValueError):
        if should_raise:
            raise
        _logger.info("Base image not found. Generating new base image")
        return False
    return True
def pull_docker_image(docker_image: str) -> None:
    """Pull ``docker_image`` with ``docker pull``.

    Raises:
        LaunchError: If the docker command reports an error; the original
            DockerError is chained as the cause.
    """
    try:
        docker.run(["docker", "pull", docker_image])
    except docker.DockerError as e:
        raise LaunchError(f"Docker server returned error: {e}") from e
def macro_sub(original: str, sub_dict: dict[str, str | None]) -> str:
    """Replace ``${name}`` macros in ``original`` with values from ``sub_dict``.

    A macro whose name is not a key of ``sub_dict`` is left as written. A key
    that is present is substituted with ``str(value)``, so a None value becomes
    the text ``"None"``.

    Args:
        original: The string to substitute macros in.
        sub_dict: Mapping of macro name to replacement value.

    Returns:
        The string with the macros substituted.
    """

    def _expand(found: re.Match) -> str:
        # Falling back to the whole match leaves unknown macros untouched.
        return str(sub_dict.get(found.group(1), found.group(0)))

    return MACRO_REGEX.sub(_expand, original)
def recursive_macro_sub(source: Any, sub_dict: dict[str, str | None]) -> Any:
    """Apply ``macro_sub`` to every string leaf of a parsed JSON/YAML value.

    Lists and dicts are rebuilt recursively; dict keys are not substituted.
    Any other value (numbers, booleans, None, tuples, ...) is returned as-is.

    Arguments:
        source: The JSON or YAML blob to substitute macros in.
        sub_dict: Mapping of macro name to replacement value.

    Returns:
        A new blob with the macros substituted.
    """
    if isinstance(source, str):
        return macro_sub(source, sub_dict)
    if isinstance(source, list):
        return [recursive_macro_sub(element, sub_dict) for element in source]
    if isinstance(source, dict):
        return {k: recursive_macro_sub(v, sub_dict) for k, v in source.items()}
    return source
def fetch_and_validate_template_variables(
    runqueue: Any, fields: dict
) -> dict[str, Any]:
    """Parse ``--set-var key=value`` overrides against a queue's template variables.

    Args:
        runqueue: The run queue. Each entry of ``runqueue.template_variables``
            provides a ``name`` and a JSON ``schema`` string, and
            ``runqueue.name`` is used in error messages.
        fields: The raw ``key=value`` strings. NOTE(review): annotated as
            ``dict`` but iterated as strings; a list/tuple of strings is what
            the parsing expects.

    Returns:
        Mapping of variable name to value. Values are converted to ``int`` or
        ``float`` when the schema type is ``integer`` or ``number``.

    Raises:
        LaunchError: If a field has no ``=``, names a variable the queue does
            not define, or cannot be converted to the schema's type.
    """
    variable_schemas = {
        tv["name"]: json.loads(tv["schema"]) for tv in runqueue.template_variables
    }
    template_variables: dict[str, Any] = {}
    for field in fields:
        # Split on the first "=" only, so values may themselves contain "="
        # (URLs with query strings, base64 padding, ...).
        key, sep, val = field.partition("=")
        if not sep:
            raise LaunchError(
                f'--set-var value must be in the format "--set-var key1=value1", instead got: {field}'
            )
        if key not in variable_schemas:
            raise LaunchError(
                f"Queue {runqueue.name} does not support overriding {key}."
            )
        field_type = variable_schemas[key].get("type")
        try:
            if field_type == "integer":
                val = int(val)
            elif field_type == "number":
                val = float(val)
        except ValueError as e:
            raise LaunchError(f"Value for {key} must be of type {field_type}.") from e
        template_variables[key] = val
    return template_variables
def get_entrypoint_file(entrypoint: list[str]) -> str | None:
    """Pick the script file out of an entrypoint command.

    Args:
        entrypoint: Command and arguments, e.g. ``["python", "train.py"]``.

    Returns:
        The first element if it ends in ``.py`` or ``.sh``. Otherwise the second
        element if there is one, else None (also None for an empty command).
    """
    if not entrypoint:
        return None
    head = entrypoint[0]
    if head.endswith((".py", ".sh")):
        return head
    return entrypoint[1] if len(entrypoint) >= 2 else None
def get_current_python_version() -> tuple[str, str]:
    """Return the running interpreter's ``("major.minor", "major")`` version strings."""
    info = sys.version_info
    return f"{info[0]}.{info[1]}", str(info[0])
def yield_containers(root: dict | list) -> Iterator[dict]:
    """Yield every container spec found in a manifest.

    Walks nested dicts and lists and yields each item of any list stored under
    a ``containers`` key. Such lists are not searched further, and a
    ``containers`` value that is not a list is ignored.
    """
    if isinstance(root, list):
        for element in root:
            yield from yield_containers(element)
    elif isinstance(root, dict):
        for key, value in root.items():
            if key != "containers":
                if isinstance(value, (dict, list)):
                    yield from yield_containers(value)
            elif isinstance(value, list):
                yield from value
def sanitize_identifiers_for_k8s(root: Any) -> None:
    """Rewrite ``name`` fields in a manifest to DNS-1123 labels, in place.

    Walks nested dicts/lists and passes each of the following through
    ``make_k8s_label_safe``:

    - ``metadata.name``;
    - every container's ``name``;
    - any other string stored under a ``name`` key.

    Values under an ``env`` key are skipped. Environment variable names (e.g.
    ``WANDB_API_KEY``) are not DNS labels, and lowercasing/hyphenating them
    would change which variables the container receives.

    Raises:
        LaunchError: From ``make_k8s_label_safe`` if a name has no valid
            characters left.
    """
    if isinstance(root, list):
        for item in root:
            sanitize_identifiers_for_k8s(item)
        return
    # Only dicts have metadata and nested structures we need to sanitize.
    if not isinstance(root, dict):
        return
    metadata = root.get("metadata")
    if isinstance(metadata, dict) and (name := metadata.get("name")):
        metadata["name"] = make_k8s_label_safe(str(name))
    for container in yield_containers(root):
        # Container lists can hold non-dict junk; skip instead of crashing.
        if isinstance(container, dict) and (name := container.get("name")):
            container["name"] = make_k8s_label_safe(str(name))
    # nested names
    for key, value in root.items():
        if key == "env":
            continue
        if isinstance(value, (dict, list)):
            sanitize_identifiers_for_k8s(value)
        elif key == "name" and isinstance(value, str):
            # Reassigning an existing key does not resize the dict, so this is
            # safe while iterating over items().
            root[key] = make_k8s_label_safe(value)
|