_project_spec.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. """Convert launch arguments into a runnable wandb launch script.
  2. Arguments can come from a launch spec or call to wandb launch.
  3. """
  4. from __future__ import annotations
  5. import enum
  6. import json
  7. import logging
  8. import os
  9. import shlex
  10. import shutil
  11. import tempfile
  12. from copy import deepcopy
  13. from typing import TYPE_CHECKING, Any, cast
  14. import wandb
  15. from wandb.apis.internal import Api
  16. from wandb.errors import CommError
  17. from wandb.sdk.launch.utils import get_entrypoint_file
  18. from wandb.sdk.lib.runid import generate_id
  19. from .errors import LaunchError
  20. from .utils import CODE_MOUNT_DIR, LOG_PREFIX, recursive_macro_sub
  21. if TYPE_CHECKING:
  22. from wandb.sdk.artifacts.artifact import Artifact
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)

# Maps a launch resource name to a numeric user id.
# The user must be root (uid 0) for sagemaker so that users have access to the
# /opt/ml directories, which let them create artifacts and access input data.
RESOURCE_UID_MAP = {"local": 1000, "sagemaker": 0}
# Not referenced in the visible part of this file; presumably the maximum
# length of a generated image tag -- TODO confirm against its consumers.
IMAGE_TAG_MAX_LENGTH = 32
  28. class LaunchSource(enum.IntEnum):
  29. """Enumeration of possible sources for a launch project.
  30. Attributes:
  31. DOCKER: Source is a Docker image. This can happen if a user runs
  32. `wandb launch -d <docker-image>`.
  33. JOB: Source is a job. This is standard case.
  34. SCHEDULER: Source is a wandb sweep scheduler command.
  35. """
  36. DOCKER = 1
  37. JOB = 2
  38. SCHEDULER = 3
  39. class LaunchProject:
  40. """A launch project specification.
  41. The LaunchProject is initialized from a raw launch spec and an internal API
  42. object. The project encapsulates logic for taking a launch spec and converting
  43. it into the executable code.
  44. The LaunchProject needs to ultimately produce a full container spec for
  45. execution in docker, k8s, sagemaker, or vertex. This container spec includes:
  46. - container image uri
  47. - environment variables for configuring wandb etc.
  48. - entrypoint command and arguments
  49. - additional arguments specific to the target resource (e.g. instance type, node selector)
  50. This class is stateful and certain methods can only be called after
  51. `LaunchProject.fetch_and_validate_project()` has been called.
  52. Notes on the entrypoint:
  53. - The entrypoint is the command that will be run inside the container.
  54. - The LaunchProject stores two entrypoints
  55. - The job entrypoint is the entrypoint specified in the job's config.
  56. - The override entrypoint is the entrypoint specified in the launch spec.
  57. - The override entrypoint takes precedence over the job entrypoint.
  58. """
  59. # This init is way too long, and there are too many attributes on this sucker.
    def __init__(
        self,
        uri: str | None,
        job: str | None,
        api: Api,
        launch_spec: dict[str, Any],
        target_entity: str,
        target_project: str,
        name: str | None,
        docker_config: dict[str, Any],
        git_info: dict[str, str],
        overrides: dict[str, Any],
        resource: str,
        resource_args: dict[str, Any],
        run_id: str | None,
        sweep_id: str | None = None,
    ):
        """Initialize the project state from the pieces of a launch spec.

        Args:
            uri: Optional uri. A uri starting with "placeholder" marks a
                scheduler source (see `init_source`).
            job: Optional job name. When given, a "Launching job" message is
                logged.
            api: Internal API object, stored as `self.api`.
            launch_spec: The raw launch spec. The "author", "python_version"
                and "image_uri" keys are read here.
            target_entity: Entity the run is launched to.
            target_project: Project the run is launched to; stored lowercased.
            name: Optional run name.
            docker_config: Docker settings; the "docker_image" and "user_id"
                keys are read here.
            git_info: Git settings, forwarded to `init_git`.
            overrides: Launch overrides, forwarded to `init_overrides`.
            resource: Name of the target resource; also the key under which
                its arguments live in `resource_args`.
            resource_args: Per-resource arguments. The caller's dict is not
                modified; a deep copy is stored.
            run_id: Run id to use. A new id is generated when this is falsy.
            sweep_id: Optional sweep id.

        Raises:
            LaunchError: If a docker image is given together with the
                "local-process" resource (raised by the `docker_image` setter).
        """
        self.uri = uri
        self.job = job
        if job is not None:
            wandb.termlog(f"{LOG_PREFIX}Launching job: {job}")
        self._job_artifact: Artifact | None = None
        self.api = api
        self.launch_spec = launch_spec
        self.target_entity = target_entity
        self.target_project = target_project.lower()
        self.name = name  # TODO: replace with run_id
        # the builder key can be passed in through the resource args
        # but these resource_args are then passed to the appropriate
        # runner, so we need to pop the builder key out
        resource_args_copy = deepcopy(resource_args)
        self._resource_args_build = resource_args_copy.get(resource, {}).pop(
            "builder", {}
        )
        # `resource` must be assigned before `docker_image` below: the
        # docker_image setter validates against `self.resource`.
        self.resource = resource
        self.resource_args = resource_args_copy
        self.sweep_id = sweep_id
        self.author = launch_spec.get("author")
        self.python_version: str | None = launch_spec.get("python_version")
        self._job_dockerfile: str | None = None
        self._job_build_context: str | None = None
        self._job_base_image: str | None = None
        # Builder settings: "accelerator" is consulted first, then "cuda".
        self.accelerator_base_image: str | None = self._resource_args_build.get(
            "accelerator", {}
        ).get("base_image") or self._resource_args_build.get("cuda", {}).get(
            "base_image"
        )
        # docker_config wins over the launch spec's "image_uri".
        self.docker_image: str | None = docker_config.get(
            "docker_image"
        ) or launch_spec.get("image_uri")  # type: ignore [assignment]
        self.docker_user_id = docker_config.get("user_id", 1000)
        self._entry_point: EntryPoint | None = (
            None  # todo: keep multiple entrypoint support?
        )
        # init_source reads docker_image and override_entrypoint and may
        # replace _entry_point, so it has to run after the assignments above
        # and after init_overrides.
        self.init_overrides(overrides)
        self.init_source()
        self.init_git(git_info)
        self.deps_type: str | None = None
        self._runtime: str | None = None
        self.run_id = run_id or generate_id()
        self._queue_name: str | None = None
        self._queue_entity: str | None = None
        self._run_queue_item_id: str | None = None
        self._job_source_type: str | None = None
        self._job_source_info: dict[str, Any] = {}
        self._auto_default_base_image: bool = False
  126. def init_source(self) -> None:
  127. if self.docker_image is not None:
  128. self.source = LaunchSource.DOCKER
  129. self.project_dir = None
  130. elif self.job is not None:
  131. self.source = LaunchSource.JOB
  132. self.project_dir = tempfile.mkdtemp()
  133. elif self.uri and self.uri.startswith("placeholder"):
  134. self.source = LaunchSource.SCHEDULER
  135. self.project_dir = os.getcwd()
  136. self._entry_point = self.override_entrypoint
  137. def change_project_dir(self, new_dir: str) -> None:
  138. """Change the project directory to a new directory."""
  139. # Copy the contents of the old project dir to the new project dir.
  140. old_dir = self.project_dir
  141. if old_dir is not None:
  142. shutil.copytree(
  143. old_dir,
  144. new_dir,
  145. symlinks=True,
  146. dirs_exist_ok=True,
  147. ignore=shutil.ignore_patterns("fsmonitor--daemon.ipc", ".git"),
  148. )
  149. shutil.rmtree(old_dir)
  150. self.project_dir = new_dir
  151. def init_git(self, git_info: dict[str, str]) -> None:
  152. self.git_version = git_info.get("version")
  153. self.git_repo = git_info.get("repo")
  154. def init_overrides(self, overrides: dict[str, Any]) -> None:
  155. """Initialize override attributes for a launch project."""
  156. self.overrides = overrides
  157. self.override_args: list[str] = overrides.get("args", [])
  158. self.override_config: dict[str, Any] = overrides.get("run_config", {})
  159. self.override_artifacts: dict[str, Any] = overrides.get("artifacts", {})
  160. self.override_files: dict[str, Any] = overrides.get("files", {})
  161. self.override_entrypoint: EntryPoint | None = None
  162. self.override_dockerfile: str | None = overrides.get("dockerfile")
  163. override_entrypoint = overrides.get("entry_point")
  164. if override_entrypoint:
  165. _logger.info("Adding override entry point")
  166. self.override_entrypoint = EntryPoint(
  167. name=get_entrypoint_file(override_entrypoint),
  168. command=override_entrypoint,
  169. )
  170. override_working_dir = overrides.get("working_dir")
  171. self.resolved_working_dir: str = (
  172. f"{CODE_MOUNT_DIR}/{override_working_dir}"
  173. if override_working_dir
  174. else CODE_MOUNT_DIR
  175. )
  176. def __repr__(self) -> str:
  177. """String representation of LaunchProject."""
  178. if self.source == LaunchSource.JOB:
  179. return f"{self.job}"
  180. return f"{self.uri}"
  181. @classmethod
  182. def from_spec(cls, launch_spec: dict[str, Any], api: Api) -> LaunchProject:
  183. """Constructs a LaunchProject instance using a launch spec.
  184. Arguments:
  185. launch_spec: Dictionary representation of launch spec
  186. api: Instance of wandb.apis.internal Api
  187. Returns:
  188. An initialized `LaunchProject` object
  189. """
  190. name: str | None = None
  191. if launch_spec.get("name"):
  192. name = launch_spec["name"]
  193. return LaunchProject(
  194. launch_spec.get("uri"),
  195. launch_spec.get("job"),
  196. api,
  197. launch_spec,
  198. launch_spec["entity"],
  199. launch_spec["project"],
  200. name,
  201. launch_spec.get("docker", {}),
  202. launch_spec.get("git", {}),
  203. launch_spec.get("overrides", {}),
  204. launch_spec.get("resource"), # type: ignore [arg-type]
  205. launch_spec.get("resource_args", {}),
  206. launch_spec.get("run_id"),
  207. launch_spec.get("sweep_id", {}),
  208. )
    @property
    def job_dockerfile(self) -> str | None:
        """The dockerfile value recorded for the job, or None if unset."""
        return self._job_dockerfile

    @property
    def job_build_context(self) -> str | None:
        """The build context value recorded for the job, or None if unset."""
        return self._job_build_context

    @property
    def job_base_image(self) -> str | None:
        """The base image recorded for the job, or None if unset."""
        return self._job_base_image

    def set_job_dockerfile(self, dockerfile: str) -> None:
        """Record the dockerfile value for the job."""
        self._job_dockerfile = dockerfile

    def set_job_build_context(self, build_context: str) -> None:
        """Record the build context value for the job."""
        self._job_build_context = build_context

    def set_job_base_image(self, base_image: str) -> None:
        """Record the base image for the job."""
        self._job_base_image = base_image
  224. @property
  225. def image_name(self) -> str:
  226. if self.job_base_image is not None:
  227. return self.job_base_image
  228. if self.docker_image is not None:
  229. return self.docker_image
  230. elif self.uri is not None:
  231. cleaned_uri = self.uri.replace("https://", "/")
  232. first_sep = cleaned_uri.find("/")
  233. shortened_uri = cleaned_uri[first_sep:]
  234. return wandb.util.make_docker_image_name_safe(shortened_uri)
  235. else:
  236. # this will always pass since one of these 3 is required
  237. assert self.job is not None
  238. return wandb.util.make_docker_image_name_safe(self.job.split(":")[0])
    @property
    def queue_name(self) -> str | None:
        """Name of the launch queue, or None if unset."""
        return self._queue_name

    @queue_name.setter
    def queue_name(self, value: str) -> None:
        self._queue_name = value

    @property
    def queue_entity(self) -> str | None:
        """Entity of the launch queue, or None if unset."""
        return self._queue_entity

    @queue_entity.setter
    def queue_entity(self, value: str) -> None:
        self._queue_entity = value

    @property
    def run_queue_item_id(self) -> str | None:
        """Id of the run queue item, or None if unset."""
        return self._run_queue_item_id

    @run_queue_item_id.setter
    def run_queue_item_id(self, value: str) -> None:
        self._run_queue_item_id = value

    @property
    def job_source_type(self) -> str | None:
        """Source type recorded for the job, or None if unset."""
        return self._job_source_type

    def set_job_source_type(self, source_type: str) -> None:
        """Record the source type for the job."""
        self._job_source_type = source_type

    @property
    def job_source_info(self) -> dict[str, Any]:
        """Source info recorded for the job; an empty dict until set."""
        return self._job_source_info

    def set_job_source_info(self, source_info: dict[str, Any]) -> None:
        """Record the source info for the job."""
        self._job_source_info = source_info
  267. def fill_macros(self, image: str) -> dict[str, Any]:
  268. """Substitute values for macros in resource arguments.
  269. Certain macros can be used in resource args. These macros allow the
  270. user to set resource args dynamically in the context of the
  271. run being launched. The macros are given in the ${macro} format. The
  272. following macros are currently supported:
  273. ${project_name} - the name of the project the run is being launched to.
  274. ${entity_name} - the owner of the project the run being launched to.
  275. ${run_id} - the id of the run being launched.
  276. ${run_name} - the name of the run that is launching.
  277. ${image_uri} - the URI of the container image for this run.
  278. Additionally, you may use ${<ENV-VAR-NAME>} to refer to the value of any
  279. environment variables that you plan to set in the environment of any
  280. agents that will receive these resource args.
  281. Calling this method will overwrite the contents of self.resource_args
  282. with the substituted values.
  283. Args:
  284. image (str): The image name to fill in for ${wandb-image}.
  285. Returns:
  286. Dict[str, Any]: The resource args with all macros filled in.
  287. """
  288. update_dict = {
  289. "project_name": self.target_project,
  290. "entity_name": self.target_entity,
  291. "run_id": self.run_id,
  292. "run_name": self.name,
  293. "image_uri": image,
  294. "author": self.author,
  295. }
  296. update_dict.update(os.environ)
  297. result = recursive_macro_sub(self.resource_args, update_dict)
  298. # recursive_macro_sub given a dict returns a dict with the same keys
  299. # but with other input types behaves differently. The cast is for mypy.
  300. return cast(dict[str, Any], result)
  301. def build_required(self) -> bool:
  302. """Checks the source to see if a build is required."""
  303. if self.job_base_image is not None:
  304. return False
  305. return self.source != LaunchSource.JOB
  306. @property
  307. def docker_image(self) -> str | None:
  308. """Returns the Docker image associated with this LaunchProject.
  309. This will only be set if an image_uri is being run outside a job.
  310. Returns:
  311. Optional[str]: The Docker image or None if not specified.
  312. """
  313. if self._docker_image:
  314. return self._docker_image
  315. return None
  316. @docker_image.setter
  317. def docker_image(self, value: str) -> None:
  318. """Sets the Docker image for the project.
  319. Args:
  320. value (str): The Docker image to set.
  321. Returns:
  322. None
  323. """
  324. self._docker_image = value
  325. self._ensure_not_docker_image_and_local_process()
  326. def get_job_entry_point(self) -> EntryPoint | None:
  327. """Returns the job entrypoint for the project."""
  328. # assuming project only has 1 entry point, pull that out
  329. # tmp fn until we figure out if we want to support multiple entry points or not
  330. if not self._entry_point:
  331. if not self.docker_image and not self.job_base_image:
  332. raise LaunchError(
  333. "Project must have at least one entry point unless docker image is specified."
  334. )
  335. return None
  336. return self._entry_point
  337. def set_job_entry_point(self, command: list[str]) -> EntryPoint:
  338. """Set job entrypoint for the project."""
  339. assert self._entry_point is None, (
  340. "Cannot set entry point twice. Use LaunchProject.override_entrypoint"
  341. )
  342. new_entrypoint = EntryPoint(name=command[-1], command=command)
  343. self._entry_point = new_entrypoint
  344. return new_entrypoint
    def fetch_and_validate_project(self) -> None:
        """Fetch whatever the launch source needs before the project is used.

        Nothing is fetched for a Docker image source. For a job source the
        job is fetched via `_fetch_job`. The method takes no arguments and
        returns nothing; it works on this instance.
        """
        if self.source == LaunchSource.DOCKER:
            return
        elif self.source == LaunchSource.JOB:
            self._fetch_job()
        # Every non-docker source is expected to have a project directory.
        assert self.project_dir is not None
  358. # Let's make sure we document this very clearly.
  359. def get_image_source_string(self) -> str:
  360. """Returns a unique string identifying the source of an image."""
  361. if self.source == LaunchSource.JOB:
  362. assert self._job_artifact is not None
  363. return f"{self._job_artifact.name}:v{self._job_artifact.version}"
  364. elif self.source == LaunchSource.DOCKER:
  365. assert isinstance(self.docker_image, str)
  366. return self.docker_image
  367. else:
  368. raise LaunchError(
  369. "Unknown source type when determining image source string"
  370. )
  371. def _ensure_not_docker_image_and_local_process(self) -> None:
  372. """Ensure that docker image is not specified with local-process resource runner.
  373. Raises:
  374. LaunchError: If docker image is specified with local-process resource runner.
  375. """
  376. if self.docker_image is not None and self.resource == "local-process":
  377. raise LaunchError(
  378. "Cannot specify docker image with local-process resource runner"
  379. )
  380. def _fetch_job(self) -> None:
  381. """Fetches the job details from the public API and configures the launch project.
  382. Raises:
  383. LaunchError: If there is an error accessing the job.
  384. """
  385. public_api = wandb.apis.public.Api()
  386. job_dir = tempfile.mkdtemp()
  387. try:
  388. job = public_api.job(self.job, path=job_dir)
  389. except CommError as e:
  390. msg = e.message
  391. raise LaunchError(
  392. f"Error accessing job {self.job}: {msg} on {public_api.settings.get('base_url')}"
  393. )
  394. job.configure_launch_project(self) # Why is this a method of the job?
  395. self._job_artifact = job._job_artifact
    def get_env_vars_dict(self, api: Api, max_env_length: int) -> dict[str, str]:
        """Generate environment variables for the project.

        Arguments:
            api: Internal API object; its "base_url" setting and `api_key`
                are exported.
            max_env_length: Maximum length of a single environment variable
                value. The serialized run config and file overrides are split
                across numbered variables when they are longer than this.

        Returns:
            Dictionary of environment variables.
        """
        env_vars = {}
        env_vars["WANDB_BASE_URL"] = api.settings("base_url")
        # An API key embedded in the launch spec wins over the key of `api`.
        override_api_key = self.launch_spec.get("_wandb_api_key")
        env_vars["WANDB_API_KEY"] = override_api_key or api.api_key
        if self.target_project:
            env_vars["WANDB_PROJECT"] = self.target_project
        env_vars["WANDB_ENTITY"] = self.target_entity
        env_vars["WANDB_LAUNCH"] = "True"
        env_vars["WANDB_RUN_ID"] = self.run_id
        if self.docker_image:
            env_vars["WANDB_DOCKER"] = self.docker_image
        if self.name is not None:
            env_vars["WANDB_NAME"] = self.name
        # The author is only exported when no API key override is in play.
        if "author" in self.launch_spec and not override_api_key:
            env_vars["WANDB_USERNAME"] = self.launch_spec["author"]
        if self.sweep_id:
            env_vars["WANDB_SWEEP_ID"] = self.sweep_id
        if self.launch_spec.get("_resume_count", 0) > 0:
            env_vars["WANDB_RESUME"] = "allow"
        if self.queue_name:
            env_vars[wandb.env.LAUNCH_QUEUE_NAME] = self.queue_name
        if self.queue_entity:
            env_vars[wandb.env.LAUNCH_QUEUE_ENTITY] = self.queue_entity
        if self.run_queue_item_id:
            env_vars[wandb.env.LAUNCH_TRACE_ID] = self.run_queue_item_id
        _inject_wandb_config_env_vars(self.override_config, env_vars, max_env_length)
        _inject_file_overrides_env_vars(self.override_files, env_vars, max_env_length)
        artifacts = {}
        # if we're spinning up a launch process from a job
        # we should tell the run to use that artifact
        if self.job:
            artifacts = {wandb.util.LAUNCH_JOB_ARTIFACT_SLOT_NAME: self.job}
        # Artifact overrides are merged last, so they replace the job entry
        # when they use the same key.
        env_vars["WANDB_ARTIFACTS"] = json.dumps(
            {**artifacts, **self.override_artifacts}
        )
        return env_vars
  439. def parse_existing_requirements(self) -> str:
  440. from packaging.requirements import InvalidRequirement, Requirement
  441. requirements_line = ""
  442. assert self.project_dir is not None
  443. base_requirements = os.path.join(self.project_dir, "requirements.txt")
  444. if os.path.exists(base_requirements):
  445. include_only = set()
  446. with open(base_requirements) as f2:
  447. for line in f2:
  448. if line.strip() == "":
  449. continue
  450. try:
  451. req = Requirement(line)
  452. name = req.name.lower()
  453. include_only.add(shlex.quote(name))
  454. except InvalidRequirement:
  455. _logger.warning(
  456. "Unable to parse line %s in requirements.txt",
  457. line,
  458. exc_info=True,
  459. )
  460. continue
  461. requirements_line += "WANDB_ONLY_INCLUDE={} ".format(",".join(include_only))
  462. if "wandb" not in requirements_line:
  463. wandb.termwarn(f"{LOG_PREFIX}wandb is not present in requirements.txt.")
  464. return requirements_line
  465. class EntryPoint:
  466. """An entry point into a wandb launch specification."""
  467. def __init__(self, name: str | None, command: list[str]):
  468. self.name = name
  469. self.command = command
  470. def update_entrypoint_path(self, new_path: str) -> None:
  471. """Updates the entrypoint path to a new path."""
  472. if len(self.command) == 2 and (
  473. self.command[0].startswith("python") or self.command[0] == "bash"
  474. ):
  475. self.command[1] = new_path
  476. def _inject_wandb_config_env_vars(
  477. config: dict[str, Any], env_dict: dict[str, Any], maximum_env_length: int
  478. ) -> None:
  479. str_config = json.dumps(config)
  480. if len(str_config) <= maximum_env_length:
  481. env_dict["WANDB_CONFIG"] = str_config
  482. return
  483. chunks = [
  484. str_config[i : i + maximum_env_length]
  485. for i in range(0, len(str_config), maximum_env_length)
  486. ]
  487. config_chunks_dict = {f"WANDB_CONFIG_{i}": chunk for i, chunk in enumerate(chunks)}
  488. env_dict.update(config_chunks_dict)
  489. def _inject_file_overrides_env_vars(
  490. overrides: dict[str, Any], env_dict: dict[str, Any], maximum_env_length: int
  491. ) -> None:
  492. str_overrides = json.dumps(overrides)
  493. if len(str_overrides) <= maximum_env_length:
  494. env_dict["WANDB_LAUNCH_FILE_OVERRIDES"] = str_overrides
  495. return
  496. chunks = [
  497. str_overrides[i : i + maximum_env_length]
  498. for i in range(0, len(str_overrides), maximum_env_length)
  499. ]
  500. overrides_chunks_dict = {
  501. f"WANDB_LAUNCH_FILE_OVERRIDES_{i}": chunk for i, chunk in enumerate(chunks)
  502. }
  503. env_dict.update(overrides_chunks_dict)