| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317 |
- from __future__ import annotations
- import asyncio
- import logging
- import os
- import shlex
- import shutil
- import subprocess
- import sys
- import threading
- from typing import TYPE_CHECKING, Any
- import wandb
- from wandb.sdk.launch.environment.abstract import AbstractEnvironment
- from wandb.sdk.launch.registry.abstract import AbstractRegistry
- from .._project_spec import LaunchProject
- from ..errors import LaunchError
- from ..utils import (
- CODE_MOUNT_DIR,
- LOG_PREFIX,
- MAX_ENV_LENGTHS,
- PROJECT_SYNCHRONOUS,
- _is_wandb_dev_uri,
- _is_wandb_local_uri,
- docker_image_exists,
- event_loop_thread_exec,
- pull_docker_image,
- sanitize_wandb_api_key,
- )
- from .abstract import AbstractRun, AbstractRunner, Status
- if TYPE_CHECKING:
- from wandb.apis.internal import Api
# Module-level logger named after this module.
# NOTE(review): not referenced by any code visible in this chunk — confirm it
# is still used elsewhere before removing.
_logger: logging.Logger = logging.getLogger(__name__)
class LocalSubmittedRun(AbstractRun):
    """Instance of ``AbstractRun`` corresponding to a subprocess launched to run an entry point command locally.

    A worker thread (registered with ``set_thread``) starts the subprocess,
    registers its ``subprocess.Popen`` handle with ``set_command_proc`` and
    accumulates the process output in ``_stdout``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Handle of the launched process; None until the worker thread has
        # started it (and forever if it never does).
        self._command_proc: subprocess.Popen | None = None
        # Output captured so far; None until the process has been started.
        self._stdout: str | None = None
        # Set by ``cancel``; polled by the worker thread.
        self._terminate_flag: bool = False
        # Worker thread that starts the process and reads its output.
        self._thread: threading.Thread | None = None

    def set_command_proc(self, command_proc: subprocess.Popen) -> None:
        """Record the process handle created by the worker thread."""
        self._command_proc = command_proc

    def set_thread(self, thread: threading.Thread) -> None:
        """Record the worker thread that starts and monitors the process."""
        self._thread = thread

    @property
    def id(self) -> str | None:
        """PID of the launched process as a string, or None if not started."""
        if self._command_proc is None:
            return None
        return str(self._command_proc.pid)

    async def wait(self) -> bool:
        """Wait for the process to exit.

        Returns:
            True if the process exited with code 0. False if it exited with any
            other code, or if the worker thread finished without ever starting
            it (e.g. the run was cancelled before launch).
        """
        assert self._thread is not None
        # The worker thread assigns ``_command_proc`` after starting the
        # process. Check before sleeping so an already-started process is not
        # delayed, and re-check after the loop: the thread may have started
        # the process and exited between the two conditions.
        while self._command_proc is None and self._thread.is_alive():
            await asyncio.sleep(5)
        command_proc = self._command_proc
        if command_proc is None:
            return False
        wait = event_loop_thread_exec(command_proc.wait)
        return int(await wait()) == 0

    async def get_logs(self) -> str | None:
        """Return the output captured so far, or None if nothing started yet."""
        return self._stdout

    async def cancel(self) -> None:
        """Request cancellation of the run.

        If the process has not been started yet, the worker thread will not
        start it. If it is already running, it is terminated right away rather
        than waiting for the worker thread to notice the flag, which only
        happens after the process emits more output.
        """
        # thread is set immediately after starting, should always exist
        assert self._thread is not None
        # Set the flag before looking at the handle: if the handle is not
        # registered yet, the worker thread checks the flag after registering.
        self._terminate_flag = True
        command_proc = self._command_proc
        if command_proc is not None and command_proc.poll() is None:
            command_proc.terminate()

    async def get_status(self) -> Status:
        """Map the process state onto a launch ``Status``.

        Before the process exists the run is "running" while the worker thread
        is alive and "stopped" once it has exited. Afterwards it is "running"
        until the process exits, then "finished" for exit code 0 and "failed"
        for any other exit code.
        """
        assert self._thread is not None, "Failed to get status, self._thread = None"
        if self._command_proc is None:
            if self._thread.is_alive():
                return Status("running")
            return Status("stopped")
        exit_code = self._command_proc.poll()
        if exit_code is None:
            return Status("running")
        if exit_code == 0:
            return Status("finished")
        return Status("failed")
class LocalContainerRunner(AbstractRunner):
    """Runner class, uses a project to create a LocalSubmittedRun."""

    def __init__(
        self,
        api: Api,
        backend_config: dict[str, Any],
        environment: AbstractEnvironment,
        registry: AbstractRegistry,
    ) -> None:
        super().__init__(api, backend_config)
        self.environment = environment
        self.registry = registry

    def _populate_docker_args(
        self, launch_project: LaunchProject, image_uri: str
    ) -> dict[str, Any]:
        """Build the ``docker run`` arguments for ``launch_project``.

        Starts from the project's ``local-container`` resource args (after
        macro substitution). Host networking is added when
        ``_is_wandb_local_uri`` matches the API base URL. For jobs with a base
        image, the project directory is mounted at ``CODE_MOUNT_DIR`` and the
        container working directory is set.
        """
        # Shallow-copy so the adjustments below never write into the
        # project's own configuration.
        docker_args: dict[str, Any] = dict(
            launch_project.fill_macros(image_uri).get("local-container", {}) or {}
        )
        if _is_wandb_local_uri(self._api.settings("base_url")):
            if sys.platform == "win32":
                docker_args["net"] = "host"
            else:
                docker_args["network"] = "host"
            if sys.platform == "linux" or sys.platform == "linux2":
                docker_args["add-host"] = "host.docker.internal:host-gateway"

        base_image = launch_project.job_base_image
        if base_image is not None:
            # Mount code into the container and set the working directory.
            # A single volume may be configured as a plain value rather than a
            # list; normalize it, and build a new list so the configured one
            # is not mutated.
            existing = docker_args.get("volume") or []
            if isinstance(existing, (list, tuple)):
                volumes = list(existing)
            else:
                volumes = [existing]
            volumes.append(f"{launch_project.project_dir}:{CODE_MOUNT_DIR}")
            docker_args["volume"] = volumes
            docker_args["workdir"] = launch_project.resolved_working_dir
        return docker_args

    async def run(
        self,
        launch_project: LaunchProject,
        image_uri: str,
    ) -> AbstractRun | None:
        """Launch ``image_uri`` locally via ``docker run`` in a subprocess.

        Returns the submitted run. When ``backend_config[PROJECT_SYNCHRONOUS]``
        is truthy, the run is awaited before returning.

        Raises:
            LaunchError: if pulling the image fails and ``docker_image_exists``
                reports that the image is not present.
        """
        import urllib.parse

        docker_args = self._populate_docker_args(launch_project, image_uri)
        synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]

        env_vars = launch_project.get_env_vars_dict(
            self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
        )

        base_url = self._api.settings("base_url")
        # When running against local port, need to swap to local docker host
        if _is_wandb_local_uri(base_url) and sys.platform == "darwin":
            # Parse the port properly instead of unpacking ``split(":")``,
            # which fails for URLs without a port or with extra colons.
            port = urllib.parse.urlsplit(base_url).port
            docker_host = "http://host.docker.internal"
            env_vars["WANDB_BASE_URL"] = (
                docker_host if port is None else f"{docker_host}:{port}"
            )
        elif _is_wandb_dev_uri(base_url):
            env_vars["WANDB_BASE_URL"] = "http://host.docker.internal:9001"

        if launch_project.docker_image or launch_project.job_base_image:
            try:
                pull_docker_image(image_uri)
            except Exception as e:
                wandb.termwarn(f"Error attempting to pull docker image {image_uri}")
                if not docker_image_exists(image_uri):
                    raise LaunchError(
                        f"Failed to pull docker image {image_uri} with error: {e}"
                    ) from e

        entrypoint = launch_project.get_job_entry_point()
        entry_cmd = None if entrypoint is None else entrypoint.command
        command_str = " ".join(
            get_docker_command(
                image_uri,
                env_vars,
                docker_args=docker_args,
                entry_cmd=entry_cmd,
                additional_args=launch_project.override_args,
            )
        ).strip()
        # Never log the API key that may be embedded in the env var flags.
        sanitized_cmd_str = sanitize_wandb_api_key(command_str)
        _msg = f"{LOG_PREFIX}Launching run in docker with command: {sanitized_cmd_str}"
        wandb.termlog(_msg)
        run = _run_entry_point(command_str, launch_project.project_dir)
        if synchronous:
            await run.wait()
        return run
def _run_entry_point(command: str, work_dir: str | None) -> AbstractRun:
    """Start ``command`` in a shell on a background thread.

    Arguments:
        command: Entry point command to run
        work_dir: Working directory in which to run the command; the current
            working directory is used when None
    Returns:
        An instance of `LocalSubmittedRun` whose worker thread has already
        been started
    """
    cwd = os.getcwd() if work_dir is None else work_dir
    submitted = LocalSubmittedRun()
    shell_args = _shell_command(command)
    # The subprocess inherits a snapshot of the agent's environment.
    worker = threading.Thread(
        target=_thread_process_runner,
        args=(submitted, shell_args, cwd, os.environ.copy()),
    )
    submitted.set_thread(worker)
    worker.start()
    return submitted
def _shell_command(command: str) -> list[str]:
    """Wrap ``command`` in the platform shell invocation.

    On Windows this is ``cmd /C``. Elsewhere it is ``bash -c``, falling back
    to ``sh -c`` when bash is not found on PATH.

    Raises:
        LaunchError: if neither bash nor sh can be located.
    """
    if os.name == "nt":
        return ["cmd", "/C", command]
    for candidate in ("bash", "sh"):
        shell_path = shutil.which(candidate)
        if shell_path:
            return [shell_path, "-c", command]
    raise LaunchError(
        "Could not launch command: no compatible shell found (expected bash or sh)."
    )
def _decode_output_chunk(chunk: bytes, fd: int) -> str:
    """Decode a chunk of subprocess output read from ``fd`` as UTF-8.

    A read may split a multi-byte character. In that case up to three more
    bytes (the most a UTF-8 character can be missing) are read from ``fd`` one
    at a time to complete it. If the data still does not decode (invalid
    UTF-8, or EOF reached mid-character), the offending bytes are replaced
    with U+FFFD instead of retrying forever.
    """
    for _ in range(3):
        try:
            return chunk.decode()
        except UnicodeDecodeError:
            extra = os.read(fd, 1)
            if not extra:  # EOF: nothing left to complete the character
                break
            chunk += extra
    return chunk.decode(errors="replace")


def _thread_process_runner(
    run: LocalSubmittedRun, args: list[str], work_dir: str, env: dict[str, str]
) -> None:
    """Start the subprocess for ``run`` and stream its combined output.

    Intended as the worker-thread target. Output (stdout and stderr merged) is
    echoed to this process's stdout and appended to ``run._stdout`` as it
    arrives. Returns once the subprocess closes its output pipe.

    Arguments:
        run: Run to attach the process to. Its ``_terminate_flag`` is checked
            before starting the process and before every read.
        args: Command line to execute.
        work_dir: Working directory for the subprocess.
        env: Environment for the subprocess.
    """
    # cancel was called before we started the subprocess
    if run._terminate_flag:
        return
    # TODO: Make this async
    process = subprocess.Popen(
        args,
        close_fds=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        bufsize=1,
        cwd=work_dir,
        env=env,
    )
    run.set_command_proc(process)
    run._stdout = ""
    stdout = process.stdout
    assert stdout is not None  # guaranteed by stdout=subprocess.PIPE
    fd = stdout.fileno()
    terminate_sent = False
    try:
        while True:
            # the agent thread could set the terminate flag
            if run._terminate_flag and not terminate_sent:
                process.terminate()
                terminate_sent = True
            # Read raw bytes from the pipe so output is forwarded as soon as
            # it is available rather than line-by-line.
            chunk = os.read(fd, 4096)
            if not chunk:
                break
            text = _decode_output_chunk(chunk, fd)
            # NOTE(review): chunks without a carriage return get one appended
            # both in the captured log and on the terminal. This preserves the
            # existing behaviour; the intent is unclear, so confirm it.
            if "\r" in text:
                run._stdout += text
                print(text, end="")
            else:
                run._stdout += text + "\r"
                print(text, end="\r")
    finally:
        # Release our end of the pipe; otherwise the descriptor leaks.
        stdout.close()
def get_docker_command(
    image: str,
    env_vars: dict[str, str],
    entry_cmd: list[str] | None = None,
    docker_args: dict[str, Any] | None = None,
    additional_args: list[str] | None = None,
) -> list[str]:
    """Construct the docker command using the image and docker args.

    Every element of the returned list is shell-quoted, so the list can be
    joined with spaces and executed through a shell.

    Arguments:
        image: a Docker image to be run
        env_vars: a dictionary of environment variables for the command
        entry_cmd: the entry point command to run. Its first element becomes
            ``--entrypoint``; the rest are passed as arguments after the image.
        docker_args: a dictionary of additional docker args for the command.
            Single-character names become ``-x`` flags, longer names
            ``--name`` flags. List values repeat the flag once per item,
            ``True`` emits the bare flag, and ``False`` omits it.
        additional_args: extra arguments appended after the entry point args
    Returns:
        The ``docker run`` command as a list of shell-quoted tokens.
    """
    docker_path = "docker"
    cmd: list[str] = [docker_path, "run", "--rm"]
    # hacky handling of env vars, needs to be improved
    for env_key, env_value in env_vars.items():
        cmd += ["-e", f"{shlex.quote(env_key)}={shlex.quote(env_value)}"]

    for name, value in (docker_args or {}).items():
        prefix = ("-" if len(name) == 1 else "--") + shlex.quote(name)
        if isinstance(value, list):
            for v in value:
                cmd += [prefix, shlex.quote(str(v))]
        elif isinstance(value, bool):
            # Switches: a disabled flag is omitted entirely. Emitting
            # "--flag False" would make docker treat "False" as a positional
            # argument.
            if value:
                cmd.append(prefix)
        else:
            cmd += [prefix, shlex.quote(str(value))]

    # The caller joins this list and runs it through a shell, so the
    # entrypoint and trailing args must be quoted like every other token to
    # avoid word-splitting and shell injection.
    if entry_cmd:
        cmd += ["--entrypoint", shlex.quote(str(entry_cmd[0]))]
    cmd.append(shlex.quote(image))
    if entry_cmd:
        cmd += [shlex.quote(str(arg)) for arg in entry_cmd[1:]]
    if additional_args:
        cmd += [shlex.quote(str(arg)) for arg in additional_args]
    return cmd
def join(split_command: list[str]) -> str:
    """Shell-quote each element of *split_command* and join them with spaces."""
    quoted_parts = map(shlex.quote, split_command)
    return " ".join(quoted_parts)
|