local_container.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. from __future__ import annotations
  2. import asyncio
  3. import logging
  4. import os
  5. import shlex
  6. import shutil
  7. import subprocess
  8. import sys
  9. import threading
  10. from typing import TYPE_CHECKING, Any
  11. import wandb
  12. from wandb.sdk.launch.environment.abstract import AbstractEnvironment
  13. from wandb.sdk.launch.registry.abstract import AbstractRegistry
  14. from .._project_spec import LaunchProject
  15. from ..errors import LaunchError
  16. from ..utils import (
  17. CODE_MOUNT_DIR,
  18. LOG_PREFIX,
  19. MAX_ENV_LENGTHS,
  20. PROJECT_SYNCHRONOUS,
  21. _is_wandb_dev_uri,
  22. _is_wandb_local_uri,
  23. docker_image_exists,
  24. event_loop_thread_exec,
  25. pull_docker_image,
  26. sanitize_wandb_api_key,
  27. )
  28. from .abstract import AbstractRun, AbstractRunner, Status
  29. if TYPE_CHECKING:
  30. from wandb.apis.internal import Api
  31. _logger = logging.getLogger(__name__)
  32. class LocalSubmittedRun(AbstractRun):
  33. """Instance of ``AbstractRun`` corresponding to a subprocess launched to run an entry point command locally."""
  34. def __init__(self) -> None:
  35. super().__init__()
  36. self._command_proc: subprocess.Popen | None = None
  37. self._stdout: str | None = None
  38. self._terminate_flag: bool = False
  39. self._thread: threading.Thread | None = None
  40. def set_command_proc(self, command_proc: subprocess.Popen) -> None:
  41. self._command_proc = command_proc
  42. def set_thread(self, thread: threading.Thread) -> None:
  43. self._thread = thread
  44. @property
  45. def id(self) -> str | None:
  46. if self._command_proc is None:
  47. return None
  48. return str(self._command_proc.pid)
  49. async def wait(self) -> bool:
  50. assert self._thread is not None
  51. # if command proc is not set
  52. # wait for thread to set it
  53. if self._command_proc is None:
  54. while self._thread.is_alive():
  55. await asyncio.sleep(5)
  56. # command proc can be updated by another thread
  57. if self._command_proc is not None:
  58. break # type: ignore # mypy thinks this is unreachable
  59. else:
  60. return False
  61. wait = event_loop_thread_exec(self._command_proc.wait)
  62. return int(await wait()) == 0
  63. async def get_logs(self) -> str | None:
  64. return self._stdout
  65. async def cancel(self) -> None:
  66. # thread is set immediately after starting, should always exist
  67. assert self._thread is not None
  68. # cancel called before the thread subprocess has started
  69. # indicates to thread to not start command proc if not already started
  70. self._terminate_flag = True
  71. async def get_status(self) -> Status:
  72. assert self._thread is not None, "Failed to get status, self._thread = None"
  73. if self._command_proc is None:
  74. if self._thread.is_alive():
  75. return Status("running")
  76. return Status("stopped")
  77. exit_code = self._command_proc.poll()
  78. if exit_code is None:
  79. return Status("running")
  80. if exit_code == 0:
  81. return Status("finished")
  82. return Status("failed")
  83. class LocalContainerRunner(AbstractRunner):
  84. """Runner class, uses a project to create a LocallySubmittedRun."""
  85. def __init__(
  86. self,
  87. api: Api,
  88. backend_config: dict[str, Any],
  89. environment: AbstractEnvironment,
  90. registry: AbstractRegistry,
  91. ) -> None:
  92. super().__init__(api, backend_config)
  93. self.environment = environment
  94. self.registry = registry
  95. def _populate_docker_args(
  96. self, launch_project: LaunchProject, image_uri: str
  97. ) -> dict[str, Any]:
  98. docker_args: dict[str, Any] = launch_project.fill_macros(image_uri).get(
  99. "local-container", {}
  100. )
  101. if _is_wandb_local_uri(self._api.settings("base_url")):
  102. if sys.platform == "win32":
  103. docker_args["net"] = "host"
  104. else:
  105. docker_args["network"] = "host"
  106. if sys.platform == "linux" or sys.platform == "linux2":
  107. docker_args["add-host"] = "host.docker.internal:host-gateway"
  108. base_image = launch_project.job_base_image
  109. if base_image is not None:
  110. # Mount code into the container and set the working directory.
  111. if "volume" not in docker_args:
  112. docker_args["volume"] = []
  113. docker_args["volume"].append(
  114. f"{launch_project.project_dir}:{CODE_MOUNT_DIR}"
  115. )
  116. docker_args["workdir"] = launch_project.resolved_working_dir
  117. return docker_args
  118. async def run(
  119. self,
  120. launch_project: LaunchProject,
  121. image_uri: str,
  122. ) -> AbstractRun | None:
  123. docker_args = self._populate_docker_args(launch_project, image_uri)
  124. synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
  125. env_vars = launch_project.get_env_vars_dict(
  126. self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
  127. )
  128. # When running against local port, need to swap to local docker host
  129. if (
  130. _is_wandb_local_uri(self._api.settings("base_url"))
  131. and sys.platform == "darwin"
  132. ):
  133. _, _, port = self._api.settings("base_url").split(":")
  134. env_vars["WANDB_BASE_URL"] = f"http://host.docker.internal:{port}"
  135. elif _is_wandb_dev_uri(self._api.settings("base_url")):
  136. env_vars["WANDB_BASE_URL"] = "http://host.docker.internal:9001"
  137. if launch_project.docker_image or launch_project.job_base_image:
  138. try:
  139. pull_docker_image(image_uri)
  140. except Exception as e:
  141. wandb.termwarn(f"Error attempting to pull docker image {image_uri}")
  142. if not docker_image_exists(image_uri):
  143. raise LaunchError(
  144. f"Failed to pull docker image {image_uri} with error: {e}"
  145. )
  146. entrypoint = launch_project.get_job_entry_point()
  147. entry_cmd = None if entrypoint is None else entrypoint.command
  148. command_str = " ".join(
  149. get_docker_command(
  150. image_uri,
  151. env_vars,
  152. docker_args=docker_args,
  153. entry_cmd=entry_cmd,
  154. additional_args=launch_project.override_args,
  155. )
  156. ).strip()
  157. sanitized_cmd_str = sanitize_wandb_api_key(command_str)
  158. _msg = f"{LOG_PREFIX}Launching run in docker with command: {sanitized_cmd_str}"
  159. wandb.termlog(_msg)
  160. run = _run_entry_point(command_str, launch_project.project_dir)
  161. if synchronous:
  162. await run.wait()
  163. return run
  164. def _run_entry_point(command: str, work_dir: str | None) -> AbstractRun:
  165. """Run an entry point command in a subprocess.
  166. Arguments:
  167. command: Entry point command to run
  168. work_dir: Working directory in which to run the command
  169. Returns:
  170. An instance of `LocalSubmittedRun`
  171. """
  172. if work_dir is None:
  173. work_dir = os.getcwd()
  174. env = os.environ.copy()
  175. run = LocalSubmittedRun()
  176. thread = threading.Thread(
  177. target=_thread_process_runner,
  178. args=(run, _shell_command(command), work_dir, env),
  179. )
  180. run.set_thread(thread)
  181. thread.start()
  182. return run
  183. def _shell_command(command: str) -> list[str]:
  184. """Return a cross-platform shell invocation for command execution."""
  185. if os.name == "nt":
  186. return ["cmd", "/C", command]
  187. shell = shutil.which("bash") or shutil.which("sh")
  188. if shell is None:
  189. raise LaunchError(
  190. "Could not launch command: no compatible shell found (expected bash or sh)."
  191. )
  192. return [shell, "-c", command]
  193. def _thread_process_runner(
  194. run: LocalSubmittedRun, args: list[str], work_dir: str, env: dict[str, str]
  195. ) -> None:
  196. # cancel was called before we started the subprocess
  197. if run._terminate_flag:
  198. return
  199. # TODO: Make this async
  200. process = subprocess.Popen(
  201. args,
  202. close_fds=True,
  203. stdout=subprocess.PIPE,
  204. stderr=subprocess.STDOUT,
  205. universal_newlines=True,
  206. bufsize=1,
  207. cwd=work_dir,
  208. env=env,
  209. )
  210. run.set_command_proc(process)
  211. run._stdout = ""
  212. while True:
  213. # the agent thread could set the terminate flag
  214. if run._terminate_flag:
  215. process.terminate() # type: ignore
  216. chunk = os.read(process.stdout.fileno(), 4096) # type: ignore
  217. if not chunk:
  218. break
  219. index = chunk.find(b"\r")
  220. decoded_chunk = None
  221. while not decoded_chunk:
  222. try:
  223. decoded_chunk = chunk.decode()
  224. except UnicodeDecodeError:
  225. # Multi-byte character cut off, try to get the rest of it
  226. chunk += os.read(process.stdout.fileno(), 1) # type: ignore
  227. if index != -1:
  228. run._stdout += decoded_chunk
  229. print(chunk.decode(), end="")
  230. else:
  231. run._stdout += decoded_chunk + "\r"
  232. print(chunk.decode(), end="\r")
  233. def get_docker_command(
  234. image: str,
  235. env_vars: dict[str, str],
  236. entry_cmd: list[str] | None = None,
  237. docker_args: dict[str, Any] | None = None,
  238. additional_args: list[str] | None = None,
  239. ) -> list[str]:
  240. """Construct the docker command using the image and docker args.
  241. Arguments:
  242. image: a Docker image to be run
  243. env_vars: a dictionary of environment variables for the command
  244. entry_cmd: the entry point command to run
  245. docker_args: a dictionary of additional docker args for the command
  246. """
  247. docker_path = "docker"
  248. cmd: list[Any] = [docker_path, "run", "--rm"]
  249. # hacky handling of env vars, needs to be improved
  250. for env_key, env_value in env_vars.items():
  251. cmd += ["-e", f"{shlex.quote(env_key)}={shlex.quote(env_value)}"]
  252. if docker_args:
  253. for name, value in docker_args.items():
  254. if len(name) == 1:
  255. prefix = "-" + shlex.quote(name)
  256. else:
  257. prefix = "--" + shlex.quote(name)
  258. if isinstance(value, list):
  259. for v in value:
  260. cmd += [prefix, shlex.quote(str(v))]
  261. elif isinstance(value, bool) and value:
  262. cmd += [prefix]
  263. else:
  264. cmd += [prefix, shlex.quote(str(value))]
  265. if entry_cmd:
  266. cmd += ["--entrypoint", entry_cmd[0]]
  267. cmd += [shlex.quote(image)]
  268. if entry_cmd and len(entry_cmd) > 1:
  269. cmd += entry_cmd[1:]
  270. if additional_args:
  271. cmd += additional_args
  272. return cmd
  273. def join(split_command: list[str]) -> str:
  274. """Return a shell-escaped string from *split_command*."""
  275. return " ".join(shlex.quote(arg) for arg in split_command)