config_util.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. from __future__ import annotations
  2. import json
  3. import logging
  4. import os
  5. from typing import Any
  6. import wandb
  7. from wandb.errors import Error
  8. from wandb.util import load_yaml
  9. from . import filesystem
  10. logger = logging.getLogger("wandb")
  11. class ConfigError(Error):
  12. pass
  13. def dict_from_proto_list(obj_list):
  14. d = dict()
  15. for item in obj_list:
  16. d[item.key] = dict(desc=None, value=json.loads(item.value_json))
  17. return d
  18. def dict_strip_value_dict(config_dict):
  19. d = dict()
  20. for k, v in config_dict.items():
  21. d[k] = v["value"]
  22. return d
  23. def dict_no_value_from_proto_list(obj_list):
  24. d = dict()
  25. for item in obj_list:
  26. possible_dict = json.loads(item.value_json)
  27. if not isinstance(possible_dict, dict) or "value" not in possible_dict:
  28. continue
  29. d[item.key] = possible_dict["value"]
  30. return d
  31. # TODO(jhr): these functions should go away once we merge jobspec PR
  32. def save_config_file_from_dict(config_filename, config_dict):
  33. import yaml
  34. s = b"wandb_version: 1"
  35. if config_dict: # adding an empty dictionary here causes a parse error
  36. s += b"\n\n" + yaml.dump(
  37. config_dict,
  38. Dumper=yaml.SafeDumper,
  39. default_flow_style=False,
  40. allow_unicode=True,
  41. encoding="utf-8",
  42. sort_keys=False,
  43. )
  44. data = s.decode("utf-8")
  45. filesystem.mkdir_exists_ok(os.path.dirname(config_filename))
  46. with open(config_filename, "w") as conf_file:
  47. conf_file.write(data)
  48. def dict_from_config_file(
  49. filename: str, must_exist: bool = False
  50. ) -> dict[str, Any] | None:
  51. import yaml
  52. if not os.path.exists(filename):
  53. if must_exist:
  54. raise ConfigError(f"config file {filename} doesn't exist")
  55. logger.debug(f"no default config file found in {filename}")
  56. return None
  57. try:
  58. conf_file = open(filename)
  59. except OSError:
  60. raise ConfigError(f"Couldn't read config file: {filename}")
  61. try:
  62. loaded = load_yaml(conf_file)
  63. except yaml.parser.ParserError:
  64. raise ConfigError("Invalid YAML in config yaml")
  65. if loaded is None:
  66. wandb.termwarn(
  67. "Found an empty default config file (config-defaults.yaml). Proceeding with no defaults."
  68. )
  69. return None
  70. config_version = loaded.pop("wandb_version", None)
  71. if config_version is not None and config_version != 1:
  72. raise ConfigError("Unknown config version")
  73. data = dict()
  74. for k, v in loaded.items():
  75. data[k] = v["value"]
  76. return data
  77. def merge_dicts(dest: dict, src: dict) -> dict:
  78. """Recursively merge two dictionaries. Similar to Lodash's _.merge()."""
  79. for key, value in src.items():
  80. if isinstance(value, dict) and key in dest and isinstance(dest[key], dict):
  81. merge_dicts(dest[key], value)
  82. else:
  83. dest[key] = value
  84. return dest