utils.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. from __future__ import annotations
  2. import json
  3. import os
  4. import re
  5. from typing import TYPE_CHECKING, Any
  6. import wandb
  7. from wandb import util
  8. from wandb.sdk.launch.errors import LaunchError
  9. if TYPE_CHECKING:
  10. from wandb.apis.public import Api as PublicApi
# Command template used by create_sweep_command when the caller supplies no
# command of its own. Each "${...}" entry is a macro placeholder; nothing in
# this module expands these four, so presumably the sweep agent substitutes
# them later -- TODO confirm against the agent code.
DEFAULT_SWEEP_COMMAND: list[str] = [
    "${env}",
    "${interpreter}",
    "${program}",
    "${args}",
]
# Matches environment-variable macros of the form "${envvar:NAME}" and captures
# NAME (upper-case letters, digits and underscores; may be empty) as group 1.
SWEEP_COMMAND_ENV_VAR_REGEX = re.compile(r"\$\{envvar\:([A-Z0-9_]*)\}")
  18. def parse_sweep_id(parts_dict: dict) -> str | None:
  19. """In place parse sweep path from parts dict.
  20. Arguments:
  21. parts_dict (dict): dict(entity=,project=,name=). Modifies dict inplace.
  22. Returns:
  23. None or str if there is an error
  24. """
  25. entity = None
  26. project = None
  27. sweep_id = parts_dict.get("name")
  28. if not isinstance(sweep_id, str):
  29. return "Expected string sweep_id"
  30. sweep_split = sweep_id.split("/")
  31. if len(sweep_split) == 1:
  32. pass
  33. elif len(sweep_split) == 2:
  34. split_project, sweep_id = sweep_split
  35. project = split_project or project
  36. elif len(sweep_split) == 3:
  37. split_entity, split_project, sweep_id = sweep_split
  38. project = split_project or project
  39. entity = split_entity or entity
  40. else:
  41. return (
  42. "Expected sweep_id in form of sweep, project/sweep, or entity/project/sweep"
  43. )
  44. parts_dict.update(dict(name=sweep_id, project=project, entity=entity))
  45. return None
  46. def sweep_config_err_text_from_jsonschema_violations(violations: list[str]) -> str:
  47. """Consolidate schema violation strings from wandb/sweeps into a single string.
  48. Parameters
  49. ----------
  50. violations: list of str
  51. The warnings to render.
  52. Returns:
  53. -------
  54. violation: str
  55. The consolidated violation text.
  56. """
  57. violation_base = (
  58. "Malformed sweep config detected! This may cause your sweep to behave in unexpected ways.\n"
  59. "To avoid this, please fix the sweep config schema violations below:"
  60. )
  61. for i, warning in enumerate(violations):
  62. violations[i] = f" Violation {i + 1}. {warning}"
  63. violation = "\n".join([violation_base] + violations)
  64. return violation
  65. def handle_sweep_config_violations(warnings: list[str]) -> None:
  66. """Echo sweep config schema violation warnings from Gorilla to the terminal.
  67. Parameters
  68. ----------
  69. warnings: list of str
  70. The warnings to render.
  71. """
  72. warning = sweep_config_err_text_from_jsonschema_violations(warnings)
  73. if len(warnings) > 0:
  74. wandb.termwarn(warning)
  75. def load_sweep_config(sweep_config_path: str) -> dict[str, Any] | None:
  76. """Load a sweep yaml from path."""
  77. import yaml
  78. try:
  79. yaml_file = open(sweep_config_path)
  80. except OSError:
  81. wandb.termerror(f"Couldn't open sweep file: {sweep_config_path}")
  82. return None
  83. try:
  84. config: dict[str, Any] | None = yaml.safe_load(yaml_file)
  85. except yaml.YAMLError as err:
  86. wandb.termerror(f"Error in configuration file: {err}")
  87. return None
  88. if not config:
  89. wandb.termerror("Configuration file is empty")
  90. return None
  91. return config
  92. def load_launch_sweep_config(config: str | None) -> Any:
  93. if not config:
  94. return {}
  95. parsed_config = util.load_json_yaml_dict(config)
  96. if parsed_config is None:
  97. raise LaunchError(f"Could not load config from {config}. Check formatting")
  98. return parsed_config
  99. def construct_scheduler_args(
  100. sweep_config: dict[str, Any],
  101. queue: str,
  102. project: str,
  103. author: str | None = None,
  104. return_job: bool = False,
  105. ) -> list[str] | dict[str, str] | None:
  106. """Construct sweep scheduler args.
  107. logs error and returns None if misconfigured,
  108. otherwise returns args as a dict if is_job else a list of strings.
  109. """
  110. job = sweep_config.get("job")
  111. image_uri = sweep_config.get("image_uri")
  112. if not job and not image_uri: # don't allow empty string
  113. wandb.termerror(
  114. "No 'job' nor 'image_uri' top-level key found in sweep config, exactly one is required for a launch-sweep"
  115. )
  116. return None
  117. elif job and image_uri:
  118. wandb.termerror(
  119. "Sweep config has both 'job' and 'image_uri' but a launch-sweep can use only one"
  120. )
  121. return None
  122. # if scheduler is a job, return args as dict
  123. if return_job:
  124. args_dict: dict[str, str] = {
  125. "sweep_id": "WANDB_SWEEP_ID",
  126. "queue": queue,
  127. "project": project,
  128. }
  129. if job:
  130. args_dict["job"] = job
  131. elif image_uri:
  132. args_dict["image_uri"] = image_uri
  133. if author:
  134. args_dict["author"] = author
  135. return args_dict
  136. # scheduler uses cli commands, pass args as param list
  137. args = [
  138. "--queue",
  139. f"{queue!r}",
  140. "--project",
  141. f"{project!r}",
  142. ]
  143. if author:
  144. args += [
  145. "--author",
  146. f"{author!r}",
  147. ]
  148. if job:
  149. args += [
  150. "--job",
  151. f"{job!r}",
  152. ]
  153. elif image_uri:
  154. args += ["--image_uri", image_uri]
  155. return args
  156. def create_sweep_command(command: list | None = None) -> list:
  157. """Return sweep command, filling in environment variable macros."""
  158. # Start from default sweep command
  159. command = command or DEFAULT_SWEEP_COMMAND
  160. for i, chunk in enumerate(command):
  161. # Replace environment variable macros
  162. # Search a str(chunk), but allow matches to be of any (ex: int) type
  163. if SWEEP_COMMAND_ENV_VAR_REGEX.search(str(chunk)):
  164. # Replace from backwards forwards
  165. matches = list(SWEEP_COMMAND_ENV_VAR_REGEX.finditer(chunk))
  166. for m in matches[::-1]:
  167. # Default to just leaving as is if environment variable does not exist
  168. _var: str = os.environ.get(m.group(1), m.group(1))
  169. command[i] = f"{command[i][: m.start()]}{_var}{command[i][m.end() :]}"
  170. return command
  171. def create_sweep_command_args(command: dict) -> dict[str, Any]:
  172. """Create various formats of command arguments for the agent.
  173. Raises:
  174. ValueError: improperly formatted command dict
  175. """
  176. if "args" not in command:
  177. raise ValueError(f'No "args" found in command: {command}')
  178. # four different formats of command args
  179. # (1) standard command line flags (e.g. --foo=bar)
  180. flags: list[str] = []
  181. # (2) flags without hyphens (e.g. foo=bar)
  182. flags_no_hyphens: list[str] = []
  183. # (3) flags with false booleans omitted (e.g. --foo)
  184. flags_no_booleans: list[str] = []
  185. # (4) flags as a dictionary (used for constructing a json)
  186. flags_dict: dict[str, Any] = {}
  187. # (5) flags without equals (e.g. --foo bar)
  188. args_no_equals: list[str] = []
  189. # (6) flags for hydra append config value (e.g. +foo=bar)
  190. flags_append_hydra: list[str] = []
  191. # (7) flags for hydra override config value (e.g. ++foo=bar)
  192. flags_override_hydra: list[str] = []
  193. for param, config in command["args"].items():
  194. # allow 'None' as a valid value, but error if no value is found
  195. try:
  196. _value: Any = config["value"]
  197. except KeyError:
  198. raise ValueError(f'No "value" found for command["args"]["{param}"]')
  199. _flag: str = f"{param}={_value}"
  200. flags.append("--" + _flag)
  201. flags_no_hyphens.append(_flag)
  202. args_no_equals += [f"--{param}", str(_value)]
  203. flags_append_hydra.append("+" + _flag)
  204. flags_override_hydra.append("++" + _flag)
  205. if isinstance(_value, bool):
  206. # omit flags if they are boolean and false
  207. if _value:
  208. flags_no_booleans.append("--" + param)
  209. else:
  210. flags_no_booleans.append("--" + _flag)
  211. flags_dict[param] = _value
  212. return {
  213. "args": flags,
  214. "args_no_equals": args_no_equals,
  215. "args_no_hyphens": flags_no_hyphens,
  216. "args_no_boolean_flags": flags_no_booleans,
  217. "args_json": [json.dumps(flags_dict)],
  218. "args_dict": flags_dict,
  219. "args_append_hydra": flags_append_hydra,
  220. "args_override_hydra": flags_override_hydra,
  221. }
  222. def make_launch_sweep_entrypoint(
  223. args: dict[str, Any], command: list[str] | None
  224. ) -> tuple[list[str] | None, Any]:
  225. """Use args dict from create_sweep_command_args to construct entrypoint.
  226. If replace is True, remove macros from entrypoint, fill them in with args
  227. and then return the args in separate return value.
  228. """
  229. if not command:
  230. return None, None
  231. entry_point = create_sweep_command(command)
  232. macro_args = {}
  233. for macro in args:
  234. mstr = "${" + macro + "}"
  235. if mstr in entry_point:
  236. idx = entry_point.index(mstr)
  237. # only supports 1 macro per entrypoint
  238. macro_args = args[macro]
  239. entry_point = entry_point[:idx] + entry_point[idx + 1 :]
  240. if len(entry_point) == 0:
  241. return None, macro_args
  242. return entry_point, macro_args
  243. def check_job_exists(public_api: PublicApi, job: str | None) -> bool:
  244. """Check if the job exists using the public api.
  245. Returns: True if no job is passed, or if the job exists.
  246. Returns: False if the job is misformatted or doesn't exist.
  247. """
  248. if not job:
  249. return True
  250. try:
  251. public_api.job(job)
  252. except Exception as e:
  253. wandb.termerror(f"Failed to load job. {e}")
  254. return False
  255. return True
  256. def get_previous_args(
  257. run_spec: dict[str, Any],
  258. ) -> tuple[dict[str, Any], dict[str, Any]]:
  259. """Parse through previous scheduler run_spec.
  260. returns scheduler_args and settings.
  261. """
  262. scheduler_args = (
  263. run_spec.get("overrides", {}).get("run_config", {}).get("scheduler", {})
  264. )
  265. # also pipe through top level resource setup
  266. if run_spec.get("resource"):
  267. scheduler_args["resource"] = run_spec["resource"]
  268. if run_spec.get("resource_args"):
  269. scheduler_args["resource_args"] = run_spec["resource_args"]
  270. settings = run_spec.get("overrides", {}).get("run_config", {}).get("settings", {})
  271. return scheduler_args, settings