wandb_config.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. """config."""
  2. from __future__ import annotations
  3. import logging
  4. import wandb
  5. from wandb.util import (
  6. _is_artifact_representation,
  7. check_dict_contains_nested_artifact,
  8. json_friendly_val,
  9. )
  10. from . import wandb_helper
  11. from .lib import config_util
  12. logger = logging.getLogger("wandb")
  13. # TODO(jhr): consider a callback for persisting changes?
  14. # if this is done right we might make sure this is pickle-able
  15. # we might be able to do this on other objects like Run?
  16. class Config:
  17. """Config object.
  18. Config objects are intended to hold all of the hyperparameters associated
  19. with a wandb run and are saved with the run object when `wandb.init` is
  20. called.
  21. We recommend setting the config once when initializing your run by passing
  22. the `config` parameter to `init`:
  23. ```
  24. wandb.init(config=my_config_dict)
  25. ```
  26. You can create a file called `config-defaults.yaml`, and it will
  27. automatically be loaded as each run's config. You can also pass the name
  28. of the file as the `config` parameter to `init`:
  29. ```
  30. wandb.init(config="my_config.yaml")
  31. ```
  32. See https://docs.wandb.ai/models/track/config#file-based-configs.
  33. Examples:
  34. Basic usage
  35. ```
  36. with wandb.init(config={"epochs": 4}) as run:
  37. for x in range(run.config.epochs):
  38. # train
  39. ```
  40. Nested values
  41. ```
  42. with wandb.init(config={"train": {"epochs": 4}}) as run:
  43. for x in range(run.config["train"]["epochs"]):
  44. # train
  45. ```
  46. Using absl flags
  47. ```
  48. flags.DEFINE_string("model", None, "model to run") # name, default, help
  49. with wandb.init() as run:
  50. run.config.update(flags.FLAGS) # adds all absl flags to config
  51. ```
  52. Argparse flags
  53. ```python
  54. with wandb.init(config={"epochs": 4}) as run:
  55. parser = argparse.ArgumentParser()
  56. parser.add_argument(
  57. "-b",
  58. "--batch-size",
  59. type=int,
  60. default=8,
  61. metavar="N",
  62. help="input batch size for training (default: 8)",
  63. )
  64. args = parser.parse_args()
  65. run.config.update(args)
  66. ```
  67. Using TensorFlow flags (deprecated in tensorflow v2)
  68. ```python
  69. flags = tf.app.flags
  70. flags.DEFINE_string("data_dir", "/tmp/data")
  71. flags.DEFINE_integer("batch_size", 128, "Batch size.")
  72. with wandb.init() as run:
  73. run.config.update(flags.FLAGS)
  74. ```
  75. """
  76. def __init__(self):
  77. object.__setattr__(self, "_items", dict())
  78. object.__setattr__(self, "_locked", dict())
  79. object.__setattr__(self, "_users", dict())
  80. object.__setattr__(self, "_users_inv", dict())
  81. object.__setattr__(self, "_users_cnt", 0)
  82. object.__setattr__(self, "_callback", None)
  83. object.__setattr__(self, "_settings", None)
  84. object.__setattr__(self, "_artifact_callback", None)
  85. self._load_defaults()
  86. def _set_callback(self, cb):
  87. object.__setattr__(self, "_callback", cb)
  88. def _set_artifact_callback(self, cb):
  89. object.__setattr__(self, "_artifact_callback", cb)
  90. def _set_settings(self, settings):
  91. object.__setattr__(self, "_settings", settings)
  92. def __repr__(self):
  93. return str(dict(self))
  94. def keys(self):
  95. return [k for k in self._items if not k.startswith("_")]
  96. def _as_dict(self):
  97. return self._items
  98. def as_dict(self):
  99. # TODO: add telemetry, deprecate, then remove
  100. return dict(self)
  101. def __getitem__(self, key):
  102. return self._items[key]
  103. def __iter__(self):
  104. return iter(self._items)
  105. def _check_locked(self, key, ignore_locked=False) -> bool:
  106. locked = self._locked.get(key)
  107. if locked is not None:
  108. locked_user = self._users_inv[locked]
  109. if not ignore_locked:
  110. wandb.termwarn(
  111. f"Config item '{key}' was locked by '{locked_user}' (ignored update)."
  112. )
  113. return True
  114. return False
  115. def __setitem__(self, key, val):
  116. if self._check_locked(key):
  117. return
  118. with wandb.sdk.lib.telemetry.context() as tel:
  119. tel.feature.set_config_item = True
  120. self._raise_value_error_on_nested_artifact(val, nested=True)
  121. key, val = self._sanitize(key, val)
  122. self._items[key] = val
  123. logger.info("config set %s = %s - %s", key, val, self._callback)
  124. if self._callback:
  125. self._callback(key=key, val=val)
  126. def items(self):
  127. return [(k, v) for k, v in self._items.items() if not k.startswith("_")]
  128. __setattr__ = __setitem__
  129. def __getattr__(self, key):
  130. try:
  131. return self.__getitem__(key)
  132. except KeyError as ke:
  133. raise AttributeError(
  134. f"{self.__class__!r} object has no attribute {key!r}"
  135. ) from ke
  136. def __contains__(self, key):
  137. return key in self._items
  138. def _update(self, d, allow_val_change=None, ignore_locked=None):
  139. parsed_dict = wandb_helper.parse_config(d)
  140. locked_keys = set()
  141. for key in list(parsed_dict):
  142. if self._check_locked(key, ignore_locked=ignore_locked):
  143. locked_keys.add(key)
  144. sanitized = self._sanitize_dict(
  145. parsed_dict, allow_val_change, ignore_keys=locked_keys
  146. )
  147. self._items.update(sanitized)
  148. return sanitized
  149. def update(self, d, allow_val_change=None):
  150. sanitized = self._update(d, allow_val_change)
  151. if self._callback:
  152. self._callback(data=sanitized)
  153. def get(self, *args):
  154. return self._items.get(*args)
  155. def persist(self):
  156. """Call the callback if it's set."""
  157. if self._callback:
  158. self._callback(data=self._as_dict())
  159. def setdefaults(self, d):
  160. d = wandb_helper.parse_config(d)
  161. # strip out keys already configured
  162. d = {k: v for k, v in d.items() if k not in self._items}
  163. d = self._sanitize_dict(d)
  164. self._items.update(d)
  165. if self._callback:
  166. self._callback(data=d)
  167. def _get_user_id(self, user) -> int:
  168. if user not in self._users:
  169. self._users[user] = self._users_cnt
  170. self._users_inv[self._users_cnt] = user
  171. object.__setattr__(self, "_users_cnt", self._users_cnt + 1)
  172. return self._users[user]
  173. def update_locked(self, d, user=None, _allow_val_change=None):
  174. """Shallow-update config with `d` and lock config updates on d's keys."""
  175. num = self._get_user_id(user)
  176. for k, v in d.items():
  177. k, v = self._sanitize(k, v, allow_val_change=_allow_val_change)
  178. self._locked[k] = num
  179. self._items[k] = v
  180. if self._callback:
  181. self._callback(data=d)
  182. def merge_locked(self, d, user=None, _allow_val_change=None):
  183. """Recursively merge-update config with `d` and lock config updates on d's keys."""
  184. num = self._get_user_id(user)
  185. callback_d = {}
  186. for k, v in d.items():
  187. k, v = self._sanitize(k, v, allow_val_change=_allow_val_change)
  188. self._locked[k] = num
  189. if (
  190. k in self._items
  191. and isinstance(self._items[k], dict)
  192. and isinstance(v, dict)
  193. ):
  194. self._items[k] = config_util.merge_dicts(self._items[k], v)
  195. else:
  196. self._items[k] = v
  197. callback_d[k] = self._items[k]
  198. if self._callback:
  199. self._callback(data=callback_d)
  200. def _load_defaults(self):
  201. conf_dict = config_util.dict_from_config_file("config-defaults.yaml")
  202. if conf_dict is not None:
  203. self.update(conf_dict)
  204. def _sanitize_dict(
  205. self,
  206. config_dict,
  207. allow_val_change=None,
  208. ignore_keys: set | None = None,
  209. ):
  210. sanitized = {}
  211. self._raise_value_error_on_nested_artifact(config_dict)
  212. for k, v in config_dict.items():
  213. if ignore_keys and k in ignore_keys:
  214. continue
  215. k, v = self._sanitize(k, v, allow_val_change)
  216. sanitized[k] = v
  217. return sanitized
  218. def _sanitize(self, key, val, allow_val_change=None):
  219. # TODO: enable WBValues in the config in the future
  220. # refuse all WBValues which is all Media and Histograms
  221. if isinstance(val, wandb.sdk.data_types.base_types.wb_value.WBValue):
  222. raise TypeError("WBValue objects cannot be added to the run config")
  223. # Let jupyter change config freely by default
  224. if self._settings and self._settings._jupyter and allow_val_change is None:
  225. allow_val_change = True
  226. # We always normalize keys by stripping '-'
  227. key = key.strip("-")
  228. if _is_artifact_representation(val):
  229. val = self._artifact_callback(key, val)
  230. # if the user inserts an artifact into the config
  231. if not isinstance(val, wandb.Artifact):
  232. val = json_friendly_val(val)
  233. if (
  234. (not allow_val_change)
  235. and (key in self._items)
  236. and (val != self._items[key])
  237. ):
  238. raise config_util.ConfigError(
  239. f'Attempted to change value of key "{key}" '
  240. f"from {self._items[key]} to {val}\n"
  241. "If you really want to do this, pass"
  242. " allow_val_change=True to config.update()"
  243. )
  244. return key, val
  245. def _raise_value_error_on_nested_artifact(self, v, nested=False):
  246. # we can't swap nested artifacts because their root key can be locked by other values
  247. # best if we don't allow nested artifacts until we can lock nested keys in the config
  248. if isinstance(v, dict) and check_dict_contains_nested_artifact(v, nested):
  249. raise ValueError(
  250. "Instances of wandb.Artifact can only be top level keys in"
  251. " a run's config"
  252. )
  253. class ConfigStatic:
  254. def __init__(self, config):
  255. object.__setattr__(self, "__dict__", dict(config))
  256. def __setattr__(self, name, value):
  257. raise AttributeError("Error: run.config_static is a readonly object")
  258. def __setitem__(self, key, val):
  259. raise AttributeError("Error: run.config_static is a readonly object")
  260. def keys(self):
  261. return self.__dict__.keys()
  262. def __getitem__(self, key):
  263. return self.__dict__[key]
  264. def __str__(self):
  265. return str(self.__dict__)