"""Convert launch arguments into a runnable wandb launch script. Arguments can come from a launch spec or call to wandb launch. """ from __future__ import annotations import enum import json import logging import os import shlex import shutil import tempfile from copy import deepcopy from typing import TYPE_CHECKING, Any, cast import wandb from wandb.apis.internal import Api from wandb.errors import CommError from wandb.sdk.launch.utils import get_entrypoint_file from wandb.sdk.lib.runid import generate_id from .errors import LaunchError from .utils import CODE_MOUNT_DIR, LOG_PREFIX, recursive_macro_sub if TYPE_CHECKING: from wandb.sdk.artifacts.artifact import Artifact _logger = logging.getLogger(__name__) # need to make user root for sagemaker, so users have access to /opt/ml directories # that let users create artifacts and access input data RESOURCE_UID_MAP = {"local": 1000, "sagemaker": 0} IMAGE_TAG_MAX_LENGTH = 32 class LaunchSource(enum.IntEnum): """Enumeration of possible sources for a launch project. Attributes: DOCKER: Source is a Docker image. This can happen if a user runs `wandb launch -d `. JOB: Source is a job. This is standard case. SCHEDULER: Source is a wandb sweep scheduler command. """ DOCKER = 1 JOB = 2 SCHEDULER = 3 class LaunchProject: """A launch project specification. The LaunchProject is initialized from a raw launch spec an internal API object. The project encapsulates logic for taking a launch spec and converting it into the executable code. The LaunchProject needs to ultimately produce a full container spec for execution in docker, k8s, sagemaker, or vertex. This container spec includes: - container image uri - environment variables for configuring wandb etc. - entrypoint command and arguments - additional arguments specific to the target resource (e.g. instance type, node selector) This class is stateful and certain methods can only be called after `LaunchProject.fetch_and_validate_project()` has been called. Notes on the entrypoint: - The entrypoint is the command that will be run inside the container. - The LaunchProject stores two entrypoints - The job entrypoint is the entrypoint specified in the job's config. - The override entrypoint is the entrypoint specified in the launch spec. - The override entrypoint takes precedence over the job entrypoint. """ # This init is way to long, and there are too many attributes on this sucker. def __init__( self, uri: str | None, job: str | None, api: Api, launch_spec: dict[str, Any], target_entity: str, target_project: str, name: str | None, docker_config: dict[str, Any], git_info: dict[str, str], overrides: dict[str, Any], resource: str, resource_args: dict[str, Any], run_id: str | None, sweep_id: str | None = None, ): self.uri = uri self.job = job if job is not None: wandb.termlog(f"{LOG_PREFIX}Launching job: {job}") self._job_artifact: Artifact | None = None self.api = api self.launch_spec = launch_spec self.target_entity = target_entity self.target_project = target_project.lower() self.name = name # TODO: replace with run_id # the builder key can be passed in through the resource args # but these resource_args are then passed to the appropriate # runner, so we need to pop the builder key out resource_args_copy = deepcopy(resource_args) self._resource_args_build = resource_args_copy.get(resource, {}).pop( "builder", {} ) self.resource = resource self.resource_args = resource_args_copy self.sweep_id = sweep_id self.author = launch_spec.get("author") self.python_version: str | None = launch_spec.get("python_version") self._job_dockerfile: str | None = None self._job_build_context: str | None = None self._job_base_image: str | None = None self.accelerator_base_image: str | None = self._resource_args_build.get( "accelerator", {} ).get("base_image") or self._resource_args_build.get("cuda", {}).get( "base_image" ) self.docker_image: str | None = docker_config.get( "docker_image" ) or launch_spec.get("image_uri") # type: ignore [assignment] self.docker_user_id = docker_config.get("user_id", 1000) self._entry_point: EntryPoint | None = ( None # todo: keep multiple entrypoint support? ) self.init_overrides(overrides) self.init_source() self.init_git(git_info) self.deps_type: str | None = None self._runtime: str | None = None self.run_id = run_id or generate_id() self._queue_name: str | None = None self._queue_entity: str | None = None self._run_queue_item_id: str | None = None self._job_source_type: str | None = None self._job_source_info: dict[str, Any] = {} self._auto_default_base_image: bool = False def init_source(self) -> None: if self.docker_image is not None: self.source = LaunchSource.DOCKER self.project_dir = None elif self.job is not None: self.source = LaunchSource.JOB self.project_dir = tempfile.mkdtemp() elif self.uri and self.uri.startswith("placeholder"): self.source = LaunchSource.SCHEDULER self.project_dir = os.getcwd() self._entry_point = self.override_entrypoint def change_project_dir(self, new_dir: str) -> None: """Change the project directory to a new directory.""" # Copy the contents of the old project dir to the new project dir. old_dir = self.project_dir if old_dir is not None: shutil.copytree( old_dir, new_dir, symlinks=True, dirs_exist_ok=True, ignore=shutil.ignore_patterns("fsmonitor--daemon.ipc", ".git"), ) shutil.rmtree(old_dir) self.project_dir = new_dir def init_git(self, git_info: dict[str, str]) -> None: self.git_version = git_info.get("version") self.git_repo = git_info.get("repo") def init_overrides(self, overrides: dict[str, Any]) -> None: """Initialize override attributes for a launch project.""" self.overrides = overrides self.override_args: list[str] = overrides.get("args", []) self.override_config: dict[str, Any] = overrides.get("run_config", {}) self.override_artifacts: dict[str, Any] = overrides.get("artifacts", {}) self.override_files: dict[str, Any] = overrides.get("files", {}) self.override_entrypoint: EntryPoint | None = None self.override_dockerfile: str | None = overrides.get("dockerfile") override_entrypoint = overrides.get("entry_point") if override_entrypoint: _logger.info("Adding override entry point") self.override_entrypoint = EntryPoint( name=get_entrypoint_file(override_entrypoint), command=override_entrypoint, ) override_working_dir = overrides.get("working_dir") self.resolved_working_dir: str = ( f"{CODE_MOUNT_DIR}/{override_working_dir}" if override_working_dir else CODE_MOUNT_DIR ) def __repr__(self) -> str: """String representation of LaunchProject.""" if self.source == LaunchSource.JOB: return f"{self.job}" return f"{self.uri}" @classmethod def from_spec(cls, launch_spec: dict[str, Any], api: Api) -> LaunchProject: """Constructs a LaunchProject instance using a launch spec. Arguments: launch_spec: Dictionary representation of launch spec api: Instance of wandb.apis.internal Api Returns: An initialized `LaunchProject` object """ name: str | None = None if launch_spec.get("name"): name = launch_spec["name"] return LaunchProject( launch_spec.get("uri"), launch_spec.get("job"), api, launch_spec, launch_spec["entity"], launch_spec["project"], name, launch_spec.get("docker", {}), launch_spec.get("git", {}), launch_spec.get("overrides", {}), launch_spec.get("resource"), # type: ignore [arg-type] launch_spec.get("resource_args", {}), launch_spec.get("run_id"), launch_spec.get("sweep_id", {}), ) @property def job_dockerfile(self) -> str | None: return self._job_dockerfile @property def job_build_context(self) -> str | None: return self._job_build_context @property def job_base_image(self) -> str | None: return self._job_base_image def set_job_dockerfile(self, dockerfile: str) -> None: self._job_dockerfile = dockerfile def set_job_build_context(self, build_context: str) -> None: self._job_build_context = build_context def set_job_base_image(self, base_image: str) -> None: self._job_base_image = base_image @property def image_name(self) -> str: if self.job_base_image is not None: return self.job_base_image if self.docker_image is not None: return self.docker_image elif self.uri is not None: cleaned_uri = self.uri.replace("https://", "/") first_sep = cleaned_uri.find("/") shortened_uri = cleaned_uri[first_sep:] return wandb.util.make_docker_image_name_safe(shortened_uri) else: # this will always pass since one of these 3 is required assert self.job is not None return wandb.util.make_docker_image_name_safe(self.job.split(":")[0]) @property def queue_name(self) -> str | None: return self._queue_name @queue_name.setter def queue_name(self, value: str) -> None: self._queue_name = value @property def queue_entity(self) -> str | None: return self._queue_entity @queue_entity.setter def queue_entity(self, value: str) -> None: self._queue_entity = value @property def run_queue_item_id(self) -> str | None: return self._run_queue_item_id @run_queue_item_id.setter def run_queue_item_id(self, value: str) -> None: self._run_queue_item_id = value @property def job_source_type(self) -> str | None: return self._job_source_type def set_job_source_type(self, source_type: str) -> None: self._job_source_type = source_type @property def job_source_info(self) -> dict[str, Any]: return self._job_source_info def set_job_source_info(self, source_info: dict[str, Any]) -> None: self._job_source_info = source_info def fill_macros(self, image: str) -> dict[str, Any]: """Substitute values for macros in resource arguments. Certain macros can be used in resource args. These macros allow the user to set resource args dynamically in the context of the run being launched. The macros are given in the ${macro} format. The following macros are currently supported: ${project_name} - the name of the project the run is being launched to. ${entity_name} - the owner of the project the run being launched to. ${run_id} - the id of the run being launched. ${run_name} - the name of the run that is launching. ${image_uri} - the URI of the container image for this run. Additionally, you may use ${} to refer to the value of any environment variables that you plan to set in the environment of any agents that will receive these resource args. Calling this method will overwrite the contents of self.resource_args with the substituted values. Args: image (str): The image name to fill in for ${wandb-image}. Returns: Dict[str, Any]: The resource args with all macros filled in. """ update_dict = { "project_name": self.target_project, "entity_name": self.target_entity, "run_id": self.run_id, "run_name": self.name, "image_uri": image, "author": self.author, } update_dict.update(os.environ) result = recursive_macro_sub(self.resource_args, update_dict) # recursive_macro_sub given a dict returns a dict with the same keys # but with other input types behaves differently. The cast is for mypy. return cast(dict[str, Any], result) def build_required(self) -> bool: """Checks the source to see if a build is required.""" if self.job_base_image is not None: return False return self.source != LaunchSource.JOB @property def docker_image(self) -> str | None: """Returns the Docker image associated with this LaunchProject. This will only be set if an image_uri is being run outside a job. Returns: Optional[str]: The Docker image or None if not specified. """ if self._docker_image: return self._docker_image return None @docker_image.setter def docker_image(self, value: str) -> None: """Sets the Docker image for the project. Args: value (str): The Docker image to set. Returns: None """ self._docker_image = value self._ensure_not_docker_image_and_local_process() def get_job_entry_point(self) -> EntryPoint | None: """Returns the job entrypoint for the project.""" # assuming project only has 1 entry point, pull that out # tmp fn until we figure out if we want to support multiple entry points or not if not self._entry_point: if not self.docker_image and not self.job_base_image: raise LaunchError( "Project must have at least one entry point unless docker image is specified." ) return None return self._entry_point def set_job_entry_point(self, command: list[str]) -> EntryPoint: """Set job entrypoint for the project.""" assert self._entry_point is None, ( "Cannot set entry point twice. Use LaunchProject.override_entrypoint" ) new_entrypoint = EntryPoint(name=command[-1], command=command) self._entry_point = new_entrypoint return new_entrypoint def fetch_and_validate_project(self) -> None: """Fetches a project into a local directory, adds the config values to the directory, and validates the first entrypoint for the project. Arguments: launch_project: LaunchProject to fetch and validate. api: Instance of wandb.apis.internal Api Returns: A validated `LaunchProject` object. """ if self.source == LaunchSource.DOCKER: return elif self.source == LaunchSource.JOB: self._fetch_job() assert self.project_dir is not None # Let's make sure we document this very clearly. def get_image_source_string(self) -> str: """Returns a unique string identifying the source of an image.""" if self.source == LaunchSource.JOB: assert self._job_artifact is not None return f"{self._job_artifact.name}:v{self._job_artifact.version}" elif self.source == LaunchSource.DOCKER: assert isinstance(self.docker_image, str) return self.docker_image else: raise LaunchError( "Unknown source type when determining image source string" ) def _ensure_not_docker_image_and_local_process(self) -> None: """Ensure that docker image is not specified with local-process resource runner. Raises: LaunchError: If docker image is specified with local-process resource runner. """ if self.docker_image is not None and self.resource == "local-process": raise LaunchError( "Cannot specify docker image with local-process resource runner" ) def _fetch_job(self) -> None: """Fetches the job details from the public API and configures the launch project. Raises: LaunchError: If there is an error accessing the job. """ public_api = wandb.apis.public.Api() job_dir = tempfile.mkdtemp() try: job = public_api.job(self.job, path=job_dir) except CommError as e: msg = e.message raise LaunchError( f"Error accessing job {self.job}: {msg} on {public_api.settings.get('base_url')}" ) job.configure_launch_project(self) # Why is this a method of the job? self._job_artifact = job._job_artifact def get_env_vars_dict(self, api: Api, max_env_length: int) -> dict[str, str]: """Generate environment variables for the project. Arguments: launch_project: LaunchProject to generate environment variables for. Returns: Dictionary of environment variables. """ env_vars = {} env_vars["WANDB_BASE_URL"] = api.settings("base_url") override_api_key = self.launch_spec.get("_wandb_api_key") env_vars["WANDB_API_KEY"] = override_api_key or api.api_key if self.target_project: env_vars["WANDB_PROJECT"] = self.target_project env_vars["WANDB_ENTITY"] = self.target_entity env_vars["WANDB_LAUNCH"] = "True" env_vars["WANDB_RUN_ID"] = self.run_id if self.docker_image: env_vars["WANDB_DOCKER"] = self.docker_image if self.name is not None: env_vars["WANDB_NAME"] = self.name if "author" in self.launch_spec and not override_api_key: env_vars["WANDB_USERNAME"] = self.launch_spec["author"] if self.sweep_id: env_vars["WANDB_SWEEP_ID"] = self.sweep_id if self.launch_spec.get("_resume_count", 0) > 0: env_vars["WANDB_RESUME"] = "allow" if self.queue_name: env_vars[wandb.env.LAUNCH_QUEUE_NAME] = self.queue_name if self.queue_entity: env_vars[wandb.env.LAUNCH_QUEUE_ENTITY] = self.queue_entity if self.run_queue_item_id: env_vars[wandb.env.LAUNCH_TRACE_ID] = self.run_queue_item_id _inject_wandb_config_env_vars(self.override_config, env_vars, max_env_length) _inject_file_overrides_env_vars(self.override_files, env_vars, max_env_length) artifacts = {} # if we're spinning up a launch process from a job # we should tell the run to use that artifact if self.job: artifacts = {wandb.util.LAUNCH_JOB_ARTIFACT_SLOT_NAME: self.job} env_vars["WANDB_ARTIFACTS"] = json.dumps( {**artifacts, **self.override_artifacts} ) return env_vars def parse_existing_requirements(self) -> str: from packaging.requirements import InvalidRequirement, Requirement requirements_line = "" assert self.project_dir is not None base_requirements = os.path.join(self.project_dir, "requirements.txt") if os.path.exists(base_requirements): include_only = set() with open(base_requirements) as f2: for line in f2: if line.strip() == "": continue try: req = Requirement(line) name = req.name.lower() include_only.add(shlex.quote(name)) except InvalidRequirement: _logger.warning( "Unable to parse line %s in requirements.txt", line, exc_info=True, ) continue requirements_line += "WANDB_ONLY_INCLUDE={} ".format(",".join(include_only)) if "wandb" not in requirements_line: wandb.termwarn(f"{LOG_PREFIX}wandb is not present in requirements.txt.") return requirements_line class EntryPoint: """An entry point into a wandb launch specification.""" def __init__(self, name: str | None, command: list[str]): self.name = name self.command = command def update_entrypoint_path(self, new_path: str) -> None: """Updates the entrypoint path to a new path.""" if len(self.command) == 2 and ( self.command[0].startswith("python") or self.command[0] == "bash" ): self.command[1] = new_path def _inject_wandb_config_env_vars( config: dict[str, Any], env_dict: dict[str, Any], maximum_env_length: int ) -> None: str_config = json.dumps(config) if len(str_config) <= maximum_env_length: env_dict["WANDB_CONFIG"] = str_config return chunks = [ str_config[i : i + maximum_env_length] for i in range(0, len(str_config), maximum_env_length) ] config_chunks_dict = {f"WANDB_CONFIG_{i}": chunk for i, chunk in enumerate(chunks)} env_dict.update(config_chunks_dict) def _inject_file_overrides_env_vars( overrides: dict[str, Any], env_dict: dict[str, Any], maximum_env_length: int ) -> None: str_overrides = json.dumps(overrides) if len(str_overrides) <= maximum_env_length: env_dict["WANDB_LAUNCH_FILE_OVERRIDES"] = str_overrides return chunks = [ str_overrides[i : i + maximum_env_length] for i in range(0, len(str_overrides), maximum_env_length) ] overrides_chunks_dict = { f"WANDB_LAUNCH_FILE_OVERRIDES_{i}": chunk for i, chunk in enumerate(chunks) } env_dict.update(overrides_chunks_dict)