from __future__ import annotations

import json
import os
import re
from typing import TYPE_CHECKING, Any

import wandb
from wandb import util
from wandb.sdk.launch.errors import LaunchError

if TYPE_CHECKING:
    from wandb.apis.public import Api as PublicApi

# Command template used when a sweep config does not specify its own command.
DEFAULT_SWEEP_COMMAND: list[str] = [
    "${env}",
    "${interpreter}",
    "${program}",
    "${args}",
]
# Matches macros of the form ``${envvar:NAME}``; group 1 is the variable name.
SWEEP_COMMAND_ENV_VAR_REGEX = re.compile(r"\$\{envvar\:([A-Z0-9_]*)\}")


def parse_sweep_id(parts_dict: dict) -> str | None:
    """Parse the sweep path stored under ``parts_dict["name"]`` in place.

    The name may be ``sweep``, ``project/sweep`` or ``entity/project/sweep``.
    On success the dict's ``name``, ``project`` and ``entity`` keys are
    overwritten; ``project`` and ``entity`` are set to ``None`` when they are
    missing from the path or given as an empty string.

    Arguments:
        parts_dict (dict): dict(entity=,project=,name=). Modified in place.

    Returns:
        None on success, or a string describing the error. The dict is left
        untouched when an error is returned.
    """
    entity = None
    project = None
    sweep_id = parts_dict.get("name")
    if not isinstance(sweep_id, str):
        return "Expected string sweep_id"

    sweep_split = sweep_id.split("/")
    if len(sweep_split) == 1:
        pass
    elif len(sweep_split) == 2:
        split_project, sweep_id = sweep_split
        project = split_project or project
    elif len(sweep_split) == 3:
        split_entity, split_project, sweep_id = sweep_split
        project = split_project or project
        entity = split_entity or entity
    else:
        return (
            "Expected sweep_id in form of sweep, project/sweep, or entity/project/sweep"
        )

    parts_dict.update(dict(name=sweep_id, project=project, entity=entity))
    return None


def sweep_config_err_text_from_jsonschema_violations(violations: list[str]) -> str:
    """Consolidate schema violation strings into a single message.

    The input list is not modified.

    Arguments:
        violations: The violation messages to render.

    Returns:
        A header followed by one numbered line per violation, joined with
        newlines.
    """
    violation_base = (
        "Malformed sweep config detected! This may cause your sweep to behave in unexpected ways.\n"
        "To avoid this, please fix the sweep config schema violations below:"
    )
    # Build a fresh list so the caller's list is not rewritten as a side effect.
    numbered = [
        f" Violation {i}. {warning}" for i, warning in enumerate(violations, start=1)
    ]
    return "\n".join([violation_base, *numbered])


def handle_sweep_config_violations(warnings: list[str]) -> None:
    """Echo sweep config schema violation warnings to the terminal.

    Nothing is printed when there are no warnings.

    Arguments:
        warnings: The warnings to render.
    """
    if not warnings:
        return
    wandb.termwarn(sweep_config_err_text_from_jsonschema_violations(warnings))


def load_sweep_config(sweep_config_path: str) -> dict[str, Any] | None:
    """Load a sweep yaml from path.

    Arguments:
        sweep_config_path: Path of the YAML file to read.

    Returns:
        The parsed config, or None (after printing an error with
        ``wandb.termerror``) if the file cannot be opened, is not valid YAML,
        or parses to an empty/falsy value.
    """
    import yaml

    try:
        yaml_file = open(sweep_config_path)
    except OSError:
        wandb.termerror(f"Couldn't open sweep file: {sweep_config_path}")
        return None

    # The context manager guarantees the handle is closed on every path.
    with yaml_file:
        try:
            config: dict[str, Any] | None = yaml.safe_load(yaml_file)
        except yaml.YAMLError as err:
            wandb.termerror(f"Error in configuration file: {err}")
            return None

    if not config:
        wandb.termerror("Configuration file is empty")
        return None
    return config


def load_launch_sweep_config(config: str | None) -> Any:
    """Load a launch-sweep config via ``util.load_json_yaml_dict``.

    Arguments:
        config: Value handed to ``util.load_json_yaml_dict``; a falsy value
            yields an empty dict.

    Returns:
        The parsed config, or ``{}`` when ``config`` is falsy.

    Raises:
        LaunchError: if the loader returns None.
    """
    if not config:
        return {}

    parsed_config = util.load_json_yaml_dict(config)
    if parsed_config is None:
        raise LaunchError(f"Could not load config from {config}. Check formatting")
    return parsed_config


def construct_scheduler_args(
    sweep_config: dict[str, Any],
    queue: str,
    project: str,
    author: str | None = None,
    return_job: bool = False,
) -> list[str] | dict[str, str] | None:
    """Construct sweep scheduler args.

    Exactly one of the top-level ``job`` / ``image_uri`` keys must be set (and
    non-empty) in ``sweep_config``.

    Arguments:
        sweep_config: The sweep configuration.
        queue: Queue name passed through to the scheduler.
        project: Project name passed through to the scheduler.
        author: Optional author; omitted from the args when falsy.
        return_job: When True return the args as a dict, otherwise as a list
            of CLI parameter strings.

    Returns:
        None (after logging an error) if the config is misconfigured,
        otherwise the args as a dict if ``return_job`` else a list of strings.
    """
    job = sweep_config.get("job")
    image_uri = sweep_config.get("image_uri")
    if not job and not image_uri:  # don't allow empty string
        wandb.termerror(
            "No 'job' nor 'image_uri' top-level key found in sweep config, exactly one is required for a launch-sweep"
        )
        return None
    elif job and image_uri:
        wandb.termerror(
            "Sweep config has both 'job' and 'image_uri' but a launch-sweep can use only one"
        )
        return None

    # if scheduler is a job, return args as dict
    if return_job:
        args_dict: dict[str, str] = {
            "sweep_id": "WANDB_SWEEP_ID",
            "queue": queue,
            "project": project,
        }
        if job:
            args_dict["job"] = job
        elif image_uri:
            args_dict["image_uri"] = image_uri

        if author:
            args_dict["author"] = author

        return args_dict

    # scheduler uses cli commands, pass args as param list
    args = [
        "--queue",
        f"{queue!r}",
        "--project",
        f"{project!r}",
    ]
    if author:
        args += [
            "--author",
            f"{author!r}",
        ]
    if job:
        args += [
            "--job",
            f"{job!r}",
        ]
    elif image_uri:
        # NOTE(review): unlike the other values, image_uri is passed without
        # repr()-quoting; kept as-is to preserve the emitted arguments.
        args += ["--image_uri", image_uri]

    return args


def create_sweep_command(command: list | None = None) -> list:
    """Return sweep command, filling in environment variable macros.

    Every ``${envvar:NAME}`` macro inside a string chunk is replaced by the
    value of the environment variable ``NAME``; when the variable is not set,
    the macro is replaced by the bare name ``NAME``. Non-string chunks are
    left untouched.

    Arguments:
        command: The command chunks. A non-empty list is updated in place and
            returned. When None or empty, a copy of ``DEFAULT_SWEEP_COMMAND``
            is used.

    Returns:
        The command list with macros expanded.
    """
    # Copy the default so callers can never mutate the module-level template.
    command = command or list(DEFAULT_SWEEP_COMMAND)

    def _expand(match: re.Match) -> str:
        name = match.group(1)
        # Fall back to the variable name itself when it is not set.
        return os.environ.get(name, name)

    for i, chunk in enumerate(command):
        # Chunks may be of any type (ex: int); only strings can hold macros.
        if not isinstance(chunk, str):
            continue
        command[i] = SWEEP_COMMAND_ENV_VAR_REGEX.sub(_expand, chunk)
    return command


def create_sweep_command_args(command: dict) -> dict[str, Any]:
    """Create various formats of command arguments for the agent.

    Arguments:
        command: Dict with an ``"args"`` mapping of parameter name to a dict
            that holds the parameter's ``"value"``.

    Returns:
        A dict with one entry per supported argument format (see the keys of
        the returned dict literal below).

    Raises:
        ValueError: improperly formatted command dict
    """
    if "args" not in command:
        raise ValueError(f'No "args" found in command: {command}')
    # seven different formats of command args
    # (1) standard command line flags (e.g. --foo=bar)
    flags: list[str] = []
    # (2) flags without hyphens (e.g. foo=bar)
    flags_no_hyphens: list[str] = []
    # (3) flags with false booleans omitted (e.g. --foo)
    flags_no_booleans: list[str] = []
    # (4) flags as a dictionary (used for constructing a json)
    flags_dict: dict[str, Any] = {}
    # (5) flags without equals (e.g. --foo bar)
    args_no_equals: list[str] = []
    # (6) flags for hydra append config value (e.g. +foo=bar)
    flags_append_hydra: list[str] = []
    # (7) flags for hydra override config value (e.g. ++foo=bar)
    flags_override_hydra: list[str] = []
    for param, config in command["args"].items():
        # allow 'None' as a valid value, but error if no value is found
        try:
            _value: Any = config["value"]
        except KeyError:
            # The KeyError carries no extra information, so drop the context.
            raise ValueError(
                f'No "value" found for command["args"]["{param}"]'
            ) from None

        _flag: str = f"{param}={_value}"
        flags.append("--" + _flag)
        flags_no_hyphens.append(_flag)
        args_no_equals += [f"--{param}", str(_value)]
        flags_append_hydra.append("+" + _flag)
        flags_override_hydra.append("++" + _flag)
        if isinstance(_value, bool):
            # omit flags if they are boolean and false
            if _value:
                flags_no_booleans.append("--" + param)
        else:
            flags_no_booleans.append("--" + _flag)
        flags_dict[param] = _value
    return {
        "args": flags,
        "args_no_equals": args_no_equals,
        "args_no_hyphens": flags_no_hyphens,
        "args_no_boolean_flags": flags_no_booleans,
        "args_json": [json.dumps(flags_dict)],
        "args_dict": flags_dict,
        "args_append_hydra": flags_append_hydra,
        "args_override_hydra": flags_override_hydra,
    }


def make_launch_sweep_entrypoint(
    args: dict[str, Any], command: list[str] | None
) -> tuple[list[str] | None, Any]:
    """Use args dict from create_sweep_command_args to construct entrypoint.

    Environment variable macros in ``command`` are expanded first. Then, for
    each key of ``args`` whose macro ``${key}`` appears as a chunk of the
    command, that chunk is removed from the entrypoint and the corresponding
    ``args`` value is returned separately.

    Arguments:
        args: Mapping of macro name to argument value, as produced by
            ``create_sweep_command_args``.
        command: The command chunks; updated in place by the macro expansion.

    Returns:
        ``(entry_point, macro_args)``. Both are None when ``command`` is
        empty or None; ``entry_point`` is None when nothing is left after
        removing macros. ``macro_args`` is ``{}`` when no macro matched.
    """
    if not command:
        return None, None

    entry_point = create_sweep_command(command)
    macro_args = {}
    for macro in args:
        mstr = "${" + macro + "}"
        if mstr in entry_point:
            idx = entry_point.index(mstr)
            # only supports 1 macro per entrypoint
            macro_args = args[macro]
            entry_point = entry_point[:idx] + entry_point[idx + 1 :]

    if len(entry_point) == 0:
        return None, macro_args

    return entry_point, macro_args


def check_job_exists(public_api: PublicApi, job: str | None) -> bool:
    """Check if the job exists using the public api.

    Arguments:
        public_api: Api object whose ``job`` method is used for the lookup.
        job: The job to look up.

    Returns:
        True if no job is passed, or if the job exists. False (after printing
        an error) if the lookup raises, e.g. because the job is misformatted
        or doesn't exist.
    """
    if not job:
        return True

    try:
        public_api.job(job)
    except Exception as e:
        # Best-effort check: any lookup failure is reported, not raised.
        wandb.termerror(f"Failed to load job. {e}")
        return False
    return True


def get_previous_args(
    run_spec: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Parse through previous scheduler run_spec.

    Reads ``overrides.run_config.scheduler`` and
    ``overrides.run_config.settings`` from ``run_spec`` (each defaulting to an
    empty dict) and copies the top-level ``resource`` / ``resource_args``
    entries into the scheduler args when they are truthy. When the scheduler
    dict exists in ``run_spec`` it is the same object that is updated and
    returned.

    Arguments:
        run_spec: The previous scheduler run spec.

    Returns:
        ``(scheduler_args, settings)``.
    """
    scheduler_args = (
        run_spec.get("overrides", {}).get("run_config", {}).get("scheduler", {})
    )
    # also pipe through top level resource setup
    if run_spec.get("resource"):
        scheduler_args["resource"] = run_spec["resource"]
    if run_spec.get("resource_args"):
        scheduler_args["resource_args"] = run_spec["resource_args"]

    settings = run_spec.get("overrides", {}).get("run_config", {}).get("settings", {})

    return scheduler_args, settings