from __future__ import annotations

import codecs
import json
import logging
import os
import shutil
import subprocess
from typing import Any

from wandb.docker import names
from wandb.errors import Error


class DockerError(Error):
    """Raised when attempting to execute a docker command."""

    def __init__(
        self,
        command_launched: list[str],
        return_code: int,
        stdout: bytes | None = None,
        stderr: bytes | None = None,
    ) -> None:
        """Build a human-readable message describing the failed command.

        Args:
            command_launched: The full argv that was executed.
            return_code: The exit status of the process.
            stdout: Captured stdout bytes, or None if it was not captured.
            stderr: Captured stderr bytes, or None if it was not captured.
        """
        # Callers such as `run_command_live_output` pass `list[Any]`, so each
        # part is stringified rather than assumed to already be a `str`.
        command_launched_str = " ".join(str(part) for part in command_launched)
        error_msg = (
            f"The docker command executed was `{command_launched_str}`.\n"
            f"It returned with code {return_code}\n"
        )
        if stdout is not None:
            # Decode leniently: a decoding failure here would hide the
            # docker failure this exception is meant to report.
            stdout_text = stdout.decode(errors="replace")
            error_msg += f"The content of stdout is '{stdout_text}'\n"
        else:
            error_msg += (
                "The content of stdout can be found above the "
                "stacktrace (it wasn't captured).\n"
            )
        if stderr is not None:
            stderr_text = stderr.decode(errors="replace")
            error_msg += f"The content of stderr is '{stderr_text}'\n"
        else:
            error_msg += (
                "The content of stderr can be found above the "
                "stacktrace (it wasn't captured)."
            )
        super().__init__(error_msg)


# Path of the `wandb-entrypoint.sh` script that sits next to this module.
entrypoint = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "wandb-entrypoint.sh"
)
log = logging.getLogger(__name__)


def shell(cmd: list[str], input: bytes | None = None) -> str | None:
    """Run `docker <cmd>` and return its combined stdout/stderr output.

    Args:
        cmd: Arguments to pass after the `docker` executable.
        input: Optional bytes to feed to the process on stdin. When omitted,
            stdin is left untouched, exactly as before this parameter existed.

    Returns:
        The decoded, stripped output on success, or None if the command
        exited with a non-zero status (the error is printed).
    """
    # `check_output(input=None)` is NOT the same as omitting `input` (it feeds
    # an empty stdin), so the keyword is only forwarded when data is supplied.
    extra_kwargs: dict[str, Any] = {} if input is None else {"input": input}
    try:
        output = subprocess.check_output(
            ["docker"] + cmd, stderr=subprocess.STDOUT, **extra_kwargs
        )
    except subprocess.CalledProcessError as e:
        print(e)  # noqa: T201
        return None
    return output.decode("utf8").strip()


# Cached result of the buildx probe; None means "not probed yet".
_buildx_installed = None


def is_buildx_installed() -> bool:
    """Return `True` if docker buildx is installed and working."""
    global _buildx_installed
    if _buildx_installed is not None:
        return _buildx_installed  # type: ignore
    if not shutil.which("docker"):
        _buildx_installed = False
    else:
        help_output = shell(["buildx", "--help"])
        _buildx_installed = help_output is not None and "buildx" in help_output
    return _buildx_installed


def is_docker_installed() -> bool:
    """Return `True` if docker is installed and working, else `False`."""
    try:
        # Run the docker --version command
        result = subprocess.run(
            ["docker", "--version"],
            capture_output=True,
        )
    except FileNotFoundError:
        # If docker command is not found
        return False
    else:
        return result.returncode == 0


def build(
    tags: list[str], file: str, context_path: str, platform: str | None = None
) -> str:
    """Build a docker image, streaming the build output as it is produced.

    Uses `docker buildx build` when buildx is available, otherwise
    `docker build`.

    Args:
        tags: Tags to apply to the built image (one `-t` per tag).
        file: Path of the Dockerfile (`-f`).
        context_path: The build context directory.
        platform: Optional value for `--platform`.

    Returns:
        The output collected by `run_command_live_output`.

    Raises:
        DockerError: If the build command exits with a non-zero status.
    """
    use_buildx = is_buildx_installed()
    command = ["buildx", "build"] if use_buildx else ["build"]
    if use_buildx and should_add_load_argument(platform):
        command += ["--load"]
    if platform:
        command += ["--platform", platform]
    build_tags: list[str] = []
    for tag in tags:
        build_tags += ["-t", tag]
    args = ["docker"] + command + build_tags + ["-f", file, context_path]
    return run_command_live_output(args)


def should_add_load_argument(platform: str | None) -> bool:
    """Return whether `--load` can be passed for the given platform spec.

    An unset or empty platform is treated as "no platform", matching how
    `build` only adds `--platform` for a non-empty value.
    """
    # the load option does not work when multiple platforms are specified:
    # https://github.com/docker/buildx/issues/59
    return not platform or "," not in platform


def run_command_live_output(args: list[Any]) -> str:
    """Run a command, echoing its output live, and return the collected text.

    Chunks containing a carriage return are echoed but not collected; all
    other chunks are collected and returned.

    Args:
        args: The full argv to execute.

    Returns:
        The concatenation of every collected chunk.

    Raises:
        DockerError: If the process exits with a non-zero status.
    """
    # Output is read in fixed-size byte chunks, so a multi-byte UTF-8
    # character can be split across two reads. An incremental decoder carries
    # the partial bytes over instead of failing on them.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    collected: list[str] = []
    with subprocess.Popen(
        args,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        bufsize=1,
    ) as process:
        # Read straight from the file descriptor rather than through the
        # `process.stdout` file object.
        fd = process.stdout.fileno()  # type: ignore
        while True:
            chunk = os.read(fd, 4096)
            if not chunk:
                break
            text = decoder.decode(chunk)
            if not text:
                # Only part of a multi-byte character arrived so far.
                continue
            if "\r" in text:
                print(text, end="")  # noqa: T201
            else:
                collected.append(text)
                print(text, end="\r")  # noqa: T201

        # Flush any bytes still buffered in the decoder at end of stream.
        tail = decoder.decode(b"", final=True)
        if tail:
            collected.append(tail)
            print(tail, end="\r")  # noqa: T201

        stdout = "".join(collected)
        print(stdout)  # noqa: T201

    return_code = process.wait()
    if return_code != 0:
        raise DockerError(args, return_code, stdout.encode())

    return stdout


def run(
    args: list[Any],
    capture_stdout: bool = True,
    capture_stderr: bool = True,
    input: bytes | None = None,
    return_stderr: bool = False,
    env: dict[str, str] | None = None,
) -> str | tuple[str, str]:
    """Run a command to completion and return its output.

    Args:
        args: The full argv; every element is converted with `str()`.
        capture_stdout: Capture stdout instead of inheriting it.
        capture_stderr: Capture stderr instead of inheriting it.
        input: Optional bytes passed to the process on stdin.
        return_stderr: When True, return a `(stdout, stderr)` tuple.
        env: Extra environment variables layered over `os.environ`.

    Returns:
        The processed stdout, or `(stdout, stderr)` if `return_stderr` is
        set. A stream that was not captured is returned as "".

    Raises:
        DockerError: If the process exits with a non-zero status.
    """
    args = [str(x) for x in args]
    subprocess_env = dict(os.environ)
    subprocess_env.update(env or {})
    # Guard the index so an argv with fewer than two elements does not raise
    # IndexError before the command even runs.
    if len(args) > 1 and args[1] == "buildx":
        subprocess_env["DOCKER_CLI_EXPERIMENTAL"] = "enabled"
    stdout_dest: int | None = subprocess.PIPE if capture_stdout else None
    stderr_dest: int | None = subprocess.PIPE if capture_stderr else None

    completed_process = subprocess.run(
        args, input=input, stdout=stdout_dest, stderr=stderr_dest, env=subprocess_env
    )
    if completed_process.returncode != 0:
        raise DockerError(
            args,
            completed_process.returncode,
            completed_process.stdout,
            completed_process.stderr,
        )

    if return_stderr:
        return (
            _post_process_stream(completed_process.stdout),
            _post_process_stream(completed_process.stderr),
        )
    else:
        return _post_process_stream(completed_process.stdout)


def _post_process_stream(stream: bytes | None) -> str:
    """Decode a captured stream and drop a single trailing newline.

    Returns "" when the stream is None (i.e. it was not captured).
    """
    if stream is None:
        return ""
    decoded_stream = stream.decode()
    if decoded_stream.endswith("\n"):
        decoded_stream = decoded_stream[:-1]
    return decoded_stream


def default_image(gpu: bool = False) -> str:
    """Return the default `wandb/deepo` image name (`all` or `all-cpu`)."""
    tag = "all"
    if not gpu:
        tag += "-cpu"
    return f"wandb/deepo:{tag}"


def parse_repository_tag(repo_name: str) -> tuple[str, str | None]:
    """Split an image reference into `(repository, tag_or_digest)`.

    A reference containing `@` is split on the last `@`. Otherwise it is
    split on the last `:`, unless the text after it contains a `/` (in which
    case the colon belongs to a `host:port` prefix, not a tag). The second
    element is None when no tag or digest is present.
    """
    parts = repo_name.rsplit("@", 1)
    if len(parts) == 2:
        return parts[0], parts[1]
    parts = repo_name.rsplit(":", 1)
    if len(parts) == 2 and "/" not in parts[1]:
        return parts[0], parts[1]
    return repo_name, None


def parse(image_name: str) -> tuple[str, str, str]:
    """Split an image name into `(registry, repository, tag)`.

    The registry `docker.io` is reported as `index.docker.io`, and a missing
    tag defaults to `latest`.
    """
    repository, tag = parse_repository_tag(image_name)
    registry, repo_name = names.resolve_repository_name(repository)
    if registry == "docker.io":
        registry = "index.docker.io"
    return registry, repo_name, (tag or "latest")


def image_id_from_registry(image_name: str) -> str | None:
    """Query the image manifest to get its full ID including the digest.

    Args:
        image_name: The image name, such as "wandb/local".

    Returns:
        The image name followed by its digest, like "wandb/local@sha256:...",
        or None if the inspect command fails.
    """
    # https://docs.docker.com/reference/cli/docker/buildx/imagetools/inspect
    inspect_cmd = ["buildx", "imagetools", "inspect", image_name]
    format_args = ["--format", r"{{.Name}}@{{.Manifest.Digest}}"]
    return shell([*inspect_cmd, *format_args])


def image_id(image_name: str) -> str | None:
    """Retrieve the image id from the local docker daemon or remote registry."""
    if "@sha256:" in image_name:
        return image_name
    digests = shell(["inspect", image_name, "--format", "{{json .RepoDigests}}"])
    if digests is None:
        return image_id_from_registry(image_name)
    try:
        return json.loads(digests)[0]
    except (ValueError, IndexError, TypeError, KeyError):
        # Invalid JSON, an empty list, `null`, or a non-list value: nothing
        # usable locally, so fall back to asking the registry.
        return image_id_from_registry(image_name)


def get_image_uid(image_name: str) -> int:
    """Retrieve the image default uid through brute force.

    Runs `id -u` inside the image. Returns -1 if the command fails or its
    output does not end with an integer.
    """
    output = shell(["run", image_name, "id", "-u"])
    if not output:
        return -1
    # `shell` merges stderr into stdout, so other text may precede the uid;
    # only the last line is parsed.
    try:
        return int(output.splitlines()[-1].strip())
    except ValueError:
        return -1


def push(image: str, tag: str) -> str | None:
    """Push an image to a remote registry."""
    return shell(["push", f"{image}:{tag}"])


def login(username: str, password: str, registry: str) -> str | None:
    """Login to a registry.

    The password is sent on stdin (`--password-stdin`) rather than as a
    command-line argument, so it does not appear in the process list or in
    the error `shell` prints when the command fails.
    """
    return shell(
        ["login", "--username", username, "--password-stdin", registry],
        input=password.encode(),
    )


def tag(image_name: str, tag: str) -> str | None:
    """Tag an image."""
    return shell(["tag", image_name, tag])


__all__ = [
    "shell",
    "build",
    "run",
    "image_id",
    "image_id_from_registry",
    "is_docker_installed",
    "parse",
    "parse_repository_tag",
    "default_image",
    "get_image_uid",
    "push",
    "login",
    "tag",
]