| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105 |
- from __future__ import annotations
- import json
- import logging
- import os
- from typing import Any
- import wandb
- from wandb.errors import Error
- from wandb.util import load_yaml
- from . import filesystem
- logger = logging.getLogger("wandb")
class ConfigError(Error):
    """Error raised for problems with a wandb config file."""
def dict_from_proto_list(obj_list):
    """Build a ``{key: {"desc": None, "value": ...}}`` dict from keyed items.

    Each item must expose ``key`` and ``value_json`` attributes; the JSON
    text is decoded and stored under ``"value"``, while ``"desc"`` is always
    ``None``. When a key repeats, the last item wins.
    """
    return {
        entry.key: {"desc": None, "value": json.loads(entry.value_json)}
        for entry in obj_list
    }
def dict_strip_value_dict(config_dict):
    """Return a flat dict mapping each key to its entry's ``"value"`` field.

    Every value in ``config_dict`` must be subscriptable with ``"value"``;
    a missing ``"value"`` key raises ``KeyError``.
    """
    return {name: entry["value"] for name, entry in config_dict.items()}
def dict_no_value_from_proto_list(obj_list):
    """Collect the ``"value"`` field of each item whose JSON is a dict with one.

    Each item must expose ``key`` and ``value_json`` attributes. Items whose
    decoded JSON is not a dict, or is a dict without a ``"value"`` key, are
    skipped.
    """
    result = {}
    for entry in obj_list:
        decoded = json.loads(entry.value_json)
        if isinstance(decoded, dict) and "value" in decoded:
            result[entry.key] = decoded["value"]
    return result
- # TODO(jhr): these functions should go away once we merge jobspec PR
def save_config_file_from_dict(config_filename, config_dict):
    """Write ``config_dict`` to ``config_filename`` as versioned YAML.

    The file always begins with a ``wandb_version: 1`` header. When
    ``config_dict`` is truthy, its block-style YAML dump follows after a
    blank line, with keys kept in their existing order.

    Args:
        config_filename: Path of the file to write. If it has a directory
            component, that directory is passed to
            ``filesystem.mkdir_exists_ok`` before writing.
        config_dict: Mapping to serialize; a falsy value writes only the
            version header.
    """
    import yaml

    s = b"wandb_version: 1"
    if config_dict:  # adding an empty dictionary here causes a parse error
        s += b"\n\n" + yaml.dump(
            config_dict,
            Dumper=yaml.SafeDumper,
            default_flow_style=False,
            allow_unicode=True,
            encoding="utf-8",
            sort_keys=False,
        )
    data = s.decode("utf-8")

    # A bare filename has an empty dirname. The current directory already
    # exists, and makedirs-style helpers typically fail on "", so only
    # create the directory when there is one to create.
    config_dir = os.path.dirname(config_filename)
    if config_dir:
        filesystem.mkdir_exists_ok(config_dir)

    # NOTE(review): the file is opened with the platform default encoding
    # even though the payload was produced as UTF-8 with allow_unicode=True.
    # Consider switching writer and reader to encoding="utf-8" together.
    with open(config_filename, "w") as conf_file:
        conf_file.write(data)
def dict_from_config_file(
    filename: str, must_exist: bool = False
) -> dict[str, Any] | None:
    """Load a config YAML file and return its entries' ``"value"`` fields.

    Args:
        filename: Path of the YAML config file to read.
        must_exist: When True, a missing file raises ``ConfigError`` instead
            of returning ``None``.

    Returns:
        A dict mapping each top-level key (other than ``wandb_version``) to
        the ``"value"`` field of its entry, or ``None`` when the file is
        missing (and ``must_exist`` is False) or parses to nothing.

    Raises:
        ConfigError: If the file is required but missing, cannot be opened,
            fails YAML parsing, or declares a ``wandb_version`` other than 1.
    """
    import yaml

    if not os.path.exists(filename):
        if must_exist:
            raise ConfigError(f"config file {filename} doesn't exist")
        logger.debug(f"no default config file found in {filename}")
        return None

    # Only the open() call is guarded by OSError so that errors raised while
    # parsing keep propagating exactly as before.
    try:
        conf_file = open(filename)
    except OSError as e:
        raise ConfigError(f"Couldn't read config file: {filename}") from e

    # Ensure the handle is closed whether parsing succeeds or fails.
    with conf_file:
        try:
            loaded = load_yaml(conf_file)
        except yaml.parser.ParserError as e:
            raise ConfigError("Invalid YAML in config yaml") from e

    if loaded is None:
        wandb.termwarn(
            "Found an empty default config file (config-defaults.yaml). Proceeding with no defaults."
        )
        return None

    # The version marker is metadata, not a config entry; remove it before
    # extracting values.
    config_version = loaded.pop("wandb_version", None)
    if config_version is not None and config_version != 1:
        raise ConfigError("Unknown config version")

    return {k: v["value"] for k, v in loaded.items()}
def merge_dicts(dest: dict, src: dict) -> dict:
    """Deep-merge ``src`` into ``dest`` in place and return ``dest``.

    Where both sides hold a dict under the same key, the two are merged
    recursively; in every other case the value from ``src`` replaces (or
    adds) the entry in ``dest``. Comparable to Lodash's ``_.merge()``.
    """
    for name, incoming in src.items():
        current = dest.get(name)
        if isinstance(incoming, dict) and isinstance(current, dict):
            merge_dicts(current, incoming)
            continue
        dest[name] = incoming
    return dest
|