| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326 |
- from __future__ import annotations
- import json
- import os
- import re
- from typing import TYPE_CHECKING, Any
- import wandb
- from wandb import util
- from wandb.sdk.launch.errors import LaunchError
- if TYPE_CHECKING:
- from wandb.apis.public import Api as PublicApi
# Command template that create_sweep_command falls back to when the caller
# supplies no command. Each entry is a "${...}" placeholder macro; none of
# them is an "${envvar:NAME}" macro, so this module does not expand them
# itself (presumably the agent fills them in elsewhere -- TODO confirm).
DEFAULT_SWEEP_COMMAND: list[str] = [
    "${env}",
    "${interpreter}",
    "${program}",
    "${args}",
]
# Matches "${envvar:NAME}" macros inside a command chunk. Group 1 captures
# NAME, which may contain upper-case letters, digits and underscores (and,
# because of the "*" quantifier, may also be empty).
SWEEP_COMMAND_ENV_VAR_REGEX = re.compile(r"\$\{envvar\:([A-Z0-9_]*)\}")
def parse_sweep_id(parts_dict: dict) -> str | None:
    """Split a sweep path stored under ``parts_dict["name"]`` into its parts.

    Accepted forms are ``sweep``, ``project/sweep`` and
    ``entity/project/sweep``. On success the dict is updated in place:
    ``name`` becomes the bare sweep id, while ``project`` and ``entity`` are
    set from the path. A component that is absent from the path (or empty)
    is stored as ``None``, replacing any value the dict held before.

    Arguments:
        parts_dict (dict): dict(entity=,project=,name=). Modified in place.

    Returns:
        None on success, or an error message string if the name is not a
        string or has more than three path components.
    """
    raw_name = parts_dict.get("name")
    if not isinstance(raw_name, str):
        return "Expected string sweep_id"

    pieces = raw_name.split("/")
    if len(pieces) > 3:
        return (
            "Expected sweep_id in form of sweep, project/sweep, or entity/project/sweep"
        )

    # Left-pad so the components always line up as (entity, project, sweep).
    padding = [None] * (3 - len(pieces))
    entity_part, project_part, sweep_part = padding + pieces

    parts_dict.update(
        name=sweep_part,
        project=project_part or None,  # empty component counts as missing
        entity=entity_part or None,
    )
    return None
def sweep_config_err_text_from_jsonschema_violations(violations: list[str]) -> str:
    """Consolidate schema violation strings from wandb/sweeps into a single string.

    The input list is only read, never modified, so the same list can be
    rendered more than once (or inspected afterwards by the caller) without
    the entries accumulating "Violation N." prefixes.

    Arguments:
        violations (list of str): The warnings to render.

    Returns:
        The consolidated violation text: a two-line header followed by one
        numbered line per violation (numbering starts at 1). With no
        violations only the header is returned.
    """
    header = (
        "Malformed sweep config detected! This may cause your sweep to behave in unexpected ways.\n"
        "To avoid this, please fix the sweep config schema violations below:"
    )
    # Build fresh strings instead of writing back into the caller's list.
    numbered = [
        f" Violation {position}. {warning}"
        for position, warning in enumerate(violations, start=1)
    ]
    return "\n".join([header, *numbered])
def handle_sweep_config_violations(warnings: list[str]) -> None:
    """Echo sweep config schema violation warnings from Gorilla to the terminal.

    The consolidated message is always built first; it is only printed (via
    ``wandb.termwarn``) when there is at least one warning.

    Arguments:
        warnings (list of str): The warnings to render.
    """
    message = sweep_config_err_text_from_jsonschema_violations(warnings)
    if len(warnings) == 0:
        return
    wandb.termwarn(message)
def load_sweep_config(sweep_config_path: str) -> dict[str, Any] | None:
    """Load a sweep yaml from path.

    Every failure is reported with ``wandb.termerror`` and signalled by a
    ``None`` return rather than an exception.

    Arguments:
        sweep_config_path (str): Path of the YAML file to read.

    Returns:
        The parsed configuration, or None if the file could not be opened,
        is not valid YAML, or parses to an empty/falsy value.
    """
    import yaml

    try:
        yaml_file = open(sweep_config_path)
    except OSError:
        wandb.termerror(f"Couldn't open sweep file: {sweep_config_path}")
        return None

    # The context manager guarantees the handle is closed on every path,
    # including the YAML-error early return below.
    with yaml_file:
        try:
            config: dict[str, Any] | None = yaml.safe_load(yaml_file)
        except yaml.YAMLError as err:
            wandb.termerror(f"Error in configuration file: {err}")
            return None

    if not config:
        wandb.termerror("Configuration file is empty")
        return None
    return config
def load_launch_sweep_config(config: str | None) -> Any:
    """Load a launch-sweep config via ``util.load_json_yaml_dict``.

    Arguments:
        config (str | None): Value handed to ``util.load_json_yaml_dict``.
            A falsy value (None or an empty string) means "no config".

    Returns:
        An empty dict when ``config`` is falsy, otherwise whatever the
        loader returned.

    Raises:
        LaunchError: if the loader returns None.
    """
    if config:
        loaded = util.load_json_yaml_dict(config)
        if loaded is not None:
            return loaded
        raise LaunchError(f"Could not load config from {config}. Check formatting")
    return {}
def construct_scheduler_args(
    sweep_config: dict[str, Any],
    queue: str,
    project: str,
    author: str | None = None,
    return_job: bool = False,
) -> list[str] | dict[str, str] | None:
    """Construct sweep scheduler args.

    The sweep config must contain exactly one non-empty value among the
    top-level keys 'job' and 'image_uri'.

    Arguments:
        sweep_config (dict): Sweep config; only 'job' and 'image_uri' are read.
        queue (str): Queue name to pass to the scheduler.
        project (str): Project name to pass to the scheduler.
        author (str | None): Optional author; omitted from the args if falsy.
        return_job (bool): If True return the args as a dict, otherwise as a
            flat list of CLI tokens.

    Returns:
        None (after logging an error) if the config is misconfigured,
        otherwise a dict when ``return_job`` is set or a list of strings.
        In list form every value is rendered with ``repr`` except the
        image uri, which is passed through unchanged.
    """
    job = sweep_config.get("job")
    image_uri = sweep_config.get("image_uri")

    if job and image_uri:
        wandb.termerror(
            "Sweep config has both 'job' and 'image_uri' but a launch-sweep can use only one"
        )
        return None
    if not (job or image_uri):  # empty strings count as missing
        wandb.termerror(
            "No 'job' nor 'image_uri' top-level key found in sweep config, exactly one is required for a launch-sweep"
        )
        return None

    # From here on exactly one of job / image_uri is truthy.
    if return_job:
        # The scheduler runs as a job: hand the args over as a mapping.
        as_mapping: dict[str, str] = {
            "sweep_id": "WANDB_SWEEP_ID",
            "queue": queue,
            "project": project,
        }
        if job:
            as_mapping["job"] = job
        else:
            as_mapping["image_uri"] = image_uri
        if author:
            as_mapping["author"] = author
        return as_mapping

    # The scheduler is driven through the CLI: build a flat token list.
    cli_tokens = ["--queue", repr(queue), "--project", repr(project)]
    if author:
        cli_tokens.extend(["--author", repr(author)])
    if job:
        cli_tokens.extend(["--job", repr(job)])
    else:
        cli_tokens.extend(["--image_uri", image_uri])
    return cli_tokens
def create_sweep_command(command: list | None = None) -> list:
    """Return sweep command, filling in environment variable macros.

    Every ``${envvar:NAME}`` macro found in a chunk is replaced by the value
    of the environment variable ``NAME``. If that variable is not set, the
    macro is replaced by the bare name ``NAME``.

    A new list is always returned: neither the caller's ``command`` nor the
    module-level ``DEFAULT_SWEEP_COMMAND`` is modified or handed out.

    Arguments:
        command (list | None): Command chunks. A falsy value (None or an
            empty list) selects ``DEFAULT_SWEEP_COMMAND``.

    Returns:
        The command with macros expanded. Chunks without a macro are passed
        through unchanged (whatever their type); a chunk containing a macro
        is returned as a string.
    """
    template = command or DEFAULT_SWEEP_COMMAND

    def _expand(match: "re.Match[str]") -> str:
        # Fall back to the variable's own name when it is not set.
        var_name = match.group(1)
        return os.environ.get(var_name, var_name)

    expanded = []
    for chunk in template:
        # Match against the string form so non-string chunks (ex: int)
        # are handled without raising.
        text = str(chunk)
        if SWEEP_COMMAND_ENV_VAR_REGEX.search(text):
            # A callable replacement is inserted literally (no backslash
            # processing of the environment value).
            expanded.append(SWEEP_COMMAND_ENV_VAR_REGEX.sub(_expand, text))
        else:
            expanded.append(chunk)
    return expanded
def create_sweep_command_args(command: dict) -> dict[str, Any]:
    """Create various formats of command arguments for the agent.

    Arguments:
        command (dict): Must hold an "args" mapping of
            ``param -> {"value": ...}``. ``None`` is an accepted value.

    Returns:
        A dict with one entry per rendering of the same parameters:
            args: ``--foo=bar``
            args_no_equals: ``--foo``, ``bar`` as separate items
            args_no_hyphens: ``foo=bar``
            args_no_boolean_flags: like ``args``, but a True boolean is
                rendered as ``--foo`` and a False boolean is omitted
            args_json: single-item list holding the JSON dump of args_dict
            args_dict: ``{param: value}``
            args_append_hydra: ``+foo=bar``
            args_override_hydra: ``++foo=bar``

    Raises:
        ValueError: improperly formatted command dict
    """
    if "args" not in command:
        raise ValueError(f'No "args" found in command: {command}')

    with_equals: list[str] = []
    bare_pairs: list[str] = []
    split_pairs: list[str] = []
    boolean_aware: list[str] = []
    hydra_append: list[str] = []
    hydra_override: list[str] = []
    values_by_name: dict[str, Any] = {}

    for name, spec in command["args"].items():
        # A stored None is fine; a missing "value" key is not.
        try:
            value: Any = spec["value"]
        except KeyError:
            raise ValueError(f'No "value" found for command["args"]["{name}"]')

        pair = f"{name}={value}"
        with_equals.append("--" + pair)
        bare_pairs.append(pair)
        split_pairs.extend([f"--{name}", str(value)])
        hydra_append.append("+" + pair)
        hydra_override.append("++" + pair)

        if not isinstance(value, bool):
            boolean_aware.append("--" + pair)
        elif value:
            # True booleans become a bare switch; False ones are dropped.
            boolean_aware.append("--" + name)

        values_by_name[name] = value

    return {
        "args": with_equals,
        "args_no_equals": split_pairs,
        "args_no_hyphens": bare_pairs,
        "args_no_boolean_flags": boolean_aware,
        "args_json": [json.dumps(values_by_name)],
        "args_dict": values_by_name,
        "args_append_hydra": hydra_append,
        "args_override_hydra": hydra_override,
    }
def make_launch_sweep_entrypoint(
    args: dict[str, Any], command: list[str] | None
) -> tuple[list[str] | None, Any]:
    """Use args dict from create_sweep_command_args to construct entrypoint.

    The command is first passed through ``create_sweep_command``. Then, for
    each key of ``args``, the first ``${key}`` item found in the entrypoint
    is removed and the value stored under that key becomes the returned
    macro args. Only one macro per entrypoint is supported: if several keys
    match, the value of the last matching key wins.

    Arguments:
        args (dict): Mapping of macro name to its rendered arguments.
        command (list[str] | None): Command to turn into an entrypoint.

    Returns:
        ``(None, None)`` when ``command`` is falsy. Otherwise a tuple of the
        entrypoint with macros removed (None if nothing is left) and the
        macro args (an empty dict if no macro matched).
    """
    if not command:
        return None, None

    entry_point = create_sweep_command(command)
    macro_args = {}
    for macro_name, macro_value in args.items():
        placeholder = "${" + macro_name + "}"
        if placeholder not in entry_point:
            continue
        position = entry_point.index(placeholder)
        # only supports 1 macro per entrypoint
        macro_args = macro_value
        # Build a new list rather than deleting in place.
        entry_point = entry_point[:position] + entry_point[position + 1 :]

    return (entry_point or None), macro_args
def check_job_exists(public_api: PublicApi, job: str | None) -> bool:
    """Check if the job exists using the public api.

    Arguments:
        public_api (PublicApi): Api object whose ``job`` method is called.
        job (str | None): Job to look up; a falsy value skips the lookup.

    Returns:
        True if no job is passed, or if ``public_api.job(job)`` succeeds.
        False if that call raises; the error is reported with
        ``wandb.termerror``.
    """
    if job:
        try:
            public_api.job(job)
        except Exception as e:
            wandb.termerror(f"Failed to load job. {e}")
            return False
    return True
def get_previous_args(
    run_spec: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Parse through previous scheduler run_spec.

    Both results are read from ``run_spec["overrides"]["run_config"]``
    (keys "scheduler" and "settings"); any missing level yields an empty
    dict. Truthy top-level "resource" and "resource_args" entries are copied
    into the scheduler args. When the scheduler dict already exists inside
    ``run_spec`` it is that same object which is updated and returned.

    Arguments:
        run_spec (dict): Run spec of the previous scheduler.

    Returns:
        A ``(scheduler_args, settings)`` tuple.
    """
    run_config = run_spec.get("overrides", {}).get("run_config", {})
    scheduler_args = run_config.get("scheduler", {})

    # also pipe through top level resource setup
    for key in ("resource", "resource_args"):
        if run_spec.get(key):
            scheduler_args[key] = run_spec[key]

    settings = run_config.get("settings", {})
    return scheduler_args, settings
|