_launch_add.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. from __future__ import annotations
  2. import asyncio
  3. import pprint
  4. from typing import Any
  5. import wandb
  6. import wandb.apis.public as public
  7. from wandb.apis.internal import Api
  8. from wandb.errors import CommError
  9. from wandb.sdk.launch.builder.build import build_image_from_project
  10. from wandb.sdk.launch.errors import LaunchError
  11. from wandb.sdk.launch.utils import (
  12. LAUNCH_DEFAULT_PROJECT,
  13. LOG_PREFIX,
  14. construct_launch_spec,
  15. validate_launch_spec_source,
  16. )
  17. from ._project_spec import LaunchProject
  18. def push_to_queue(
  19. api: Api,
  20. queue_name: str,
  21. launch_spec: dict[str, Any],
  22. template_variables: dict | None,
  23. project_queue: str,
  24. priority: int | None = None,
  25. ) -> Any:
  26. return api.push_to_run_queue(
  27. queue_name, launch_spec, template_variables, project_queue, priority
  28. )
  29. def launch_add(
  30. uri: str | None = None,
  31. job: str | None = None,
  32. config: dict[str, Any] | None = None,
  33. template_variables: dict[str, float | int | str] | None = None,
  34. project: str | None = None,
  35. entity: str | None = None,
  36. queue_name: str | None = None,
  37. resource: str | None = None,
  38. entry_point: list[str] | None = None,
  39. name: str | None = None,
  40. version: str | None = None,
  41. docker_image: str | None = None,
  42. project_queue: str | None = None,
  43. resource_args: dict[str, Any] | None = None,
  44. run_id: str | None = None,
  45. build: bool | None = False,
  46. repository: str | None = None,
  47. sweep_id: str | None = None,
  48. author: str | None = None,
  49. priority: int | None = None,
  50. ) -> public.QueuedRun:
  51. """Enqueue a W&B launch experiment. With either a source uri, job or docker_image.
  52. Arguments:
  53. uri: URI of experiment to run. A wandb run uri or a Git repository URI.
  54. job: string reference to a wandb.Job eg: wandb/test/my-job:latest
  55. config: A dictionary containing the configuration for the run. May also contain
  56. resource specific arguments under the key "resource_args"
  57. template_variables: A dictionary containing values of template variables for a run queue.
  58. Expected format of `{"VAR_NAME": VAR_VALUE}`
  59. project: Target project to send launched run to
  60. entity: Target entity to send launched run to
  61. queue: the name of the queue to enqueue the run to
  62. priority: the priority level of the job, where 1 is the highest priority
  63. resource: Execution backend for the run: W&B provides built-in support for "local-container" backend
  64. entry_point: Entry point to run within the project. Defaults to using the entry point used
  65. in the original run for wandb URIs, or main.py for git repository URIs.
  66. name: Name run under which to launch the run.
  67. version: For Git-based projects, either a commit hash or a branch name.
  68. docker_image: The name of the docker image to use for the run.
  69. resource_args: Resource related arguments for launching runs onto a remote backend.
  70. Will be stored on the constructed launch config under ``resource_args``.
  71. run_id: optional string indicating the id of the launched run
  72. build: optional flag defaulting to false, requires queue to be set
  73. if build, an image is created, creates a job artifact, pushes a reference
  74. to that job artifact to queue
  75. repository: optional string to control the name of the remote repository, used when
  76. pushing images to a registry
  77. project_queue: optional string to control the name of the project for the queue. Primarily used
  78. for back compatibility with project scoped queues
  79. Example:
  80. ```python
  81. from wandb.sdk.launch import launch_add
  82. project_uri = "https://github.com/wandb/examples"
  83. params = {"alpha": 0.5, "l1_ratio": 0.01}
  84. # Run W&B project and create a reproducible docker environment
  85. # on a local host
  86. api = wandb.apis.internal.Api()
  87. launch_add(uri=project_uri, parameters=params)
  88. ```
  89. Returns:
  90. an instance of`wandb.api.public.QueuedRun` which gives information about the
  91. queued run, or if `wait_until_started` or `wait_until_finished` are called, gives access
  92. to the underlying Run information.
  93. Raises:
  94. `wandb.exceptions.LaunchError` if unsuccessful
  95. """
  96. api = Api()
  97. return _launch_add(
  98. api,
  99. job,
  100. config,
  101. template_variables,
  102. project,
  103. entity,
  104. queue_name,
  105. resource,
  106. entry_point,
  107. name,
  108. version,
  109. docker_image,
  110. project_queue,
  111. resource_args,
  112. run_id=run_id,
  113. build=build,
  114. repository=repository,
  115. sweep_id=sweep_id,
  116. author=author,
  117. priority=priority,
  118. )
  119. def _launch_add(
  120. api: Api,
  121. job: str | None,
  122. config: dict[str, Any] | None,
  123. template_variables: dict | None,
  124. project: str | None,
  125. entity: str | None,
  126. queue_name: str | None,
  127. resource: str | None,
  128. entry_point: list[str] | None,
  129. name: str | None,
  130. version: str | None,
  131. docker_image: str | None,
  132. project_queue: str | None,
  133. resource_args: dict[str, Any] | None = None,
  134. run_id: str | None = None,
  135. build: bool | None = False,
  136. repository: str | None = None,
  137. sweep_id: str | None = None,
  138. author: str | None = None,
  139. priority: int | None = None,
  140. ) -> public.QueuedRun:
  141. launch_spec = construct_launch_spec(
  142. None,
  143. job,
  144. api,
  145. name,
  146. project,
  147. entity,
  148. docker_image,
  149. resource,
  150. entry_point,
  151. version,
  152. resource_args,
  153. config,
  154. run_id,
  155. repository,
  156. author,
  157. sweep_id,
  158. )
  159. if build:
  160. if resource == "local-process":
  161. raise LaunchError(
  162. "Cannot build a docker image for the resource: local-process"
  163. )
  164. if launch_spec.get("job") is not None:
  165. wandb.termwarn("Build doesn't support setting a job. Overwriting job.")
  166. launch_spec["job"] = None
  167. launch_project = LaunchProject.from_spec(launch_spec, api)
  168. docker_image_uri = asyncio.run(
  169. build_image_from_project(launch_project, api, config or {})
  170. )
  171. run = wandb.run or wandb.init(
  172. project=launch_spec["project"],
  173. entity=launch_spec["entity"],
  174. job_type="launch_job",
  175. )
  176. job_artifact = run._log_job_artifact_with_image( # type: ignore
  177. docker_image_uri, launch_project.override_args
  178. )
  179. job_name = job_artifact.wait().name
  180. job = f"{launch_spec['entity']}/{launch_spec['project']}/{job_name}"
  181. launch_spec["job"] = job
  182. launch_spec["uri"] = None # Remove given URI --> now in job
  183. if queue_name is None:
  184. queue_name = "default"
  185. if project_queue is None:
  186. project_queue = LAUNCH_DEFAULT_PROJECT
  187. spec_template_vars = launch_spec.get("template_variables")
  188. if isinstance(spec_template_vars, dict):
  189. launch_spec.pop("template_variables")
  190. if template_variables is None:
  191. template_variables = spec_template_vars
  192. else:
  193. template_variables = {
  194. **spec_template_vars,
  195. **template_variables,
  196. }
  197. validate_launch_spec_source(launch_spec)
  198. res = push_to_queue(
  199. api, queue_name, launch_spec, template_variables, project_queue, priority
  200. )
  201. if res is None or "runQueueItemId" not in res:
  202. raise LaunchError("Error adding run to queue")
  203. updated_spec = res.get("runSpec")
  204. if updated_spec:
  205. if updated_spec.get("resource_args"):
  206. launch_spec["resource_args"] = updated_spec.get("resource_args")
  207. if updated_spec.get("resource"):
  208. launch_spec["resource"] = updated_spec.get("resource")
  209. if project_queue == LAUNCH_DEFAULT_PROJECT:
  210. wandb.termlog(f"{LOG_PREFIX}Added run to queue {queue_name}.")
  211. else:
  212. wandb.termlog(f"{LOG_PREFIX}Added run to queue {project_queue}/{queue_name}.")
  213. wandb.termlog(f"{LOG_PREFIX}Launch spec:\n{pprint.pformat(launch_spec)}\n")
  214. public_api = public.Api()
  215. if job is not None:
  216. try:
  217. public_api._artifact(job, type="job")
  218. except (ValueError, CommError) as e:
  219. raise LaunchError(f"Unable to fetch job with name {job}: {e}")
  220. queued_run = public_api.queued_run(
  221. launch_spec["entity"],
  222. launch_spec["project"],
  223. queue_name,
  224. res["runQueueItemId"],
  225. project_queue,
  226. priority,
  227. )
  228. return queued_run # type: ignore