| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327 |
- """config."""
- from __future__ import annotations
- import logging
- import wandb
- from wandb.util import (
- _is_artifact_representation,
- check_dict_contains_nested_artifact,
- json_friendly_val,
- )
- from . import wandb_helper
- from .lib import config_util
# Logger named "wandb" (not __name__); Config.__setitem__ logs each single-item
# assignment through it at INFO level.
logger = logging.getLogger("wandb")
# TODO(jhr): consider a callback for persisting changes?
# if this is done right we might make sure this is pickle-able
# we might be able to do this on other objects like Run?
- class Config:
- """Config object.
- Config objects are intended to hold all of the hyperparameters associated
- with a wandb run and are saved with the run object when `wandb.init` is
- called.
- We recommend setting the config once when initializing your run by passing
- the `config` parameter to `init`:
- ```
- wandb.init(config=my_config_dict)
- ```
- You can create a file called `config-defaults.yaml`, and it will
- automatically be loaded as each run's config. You can also pass the name
- of the file as the `config` parameter to `init`:
- ```
- wandb.init(config="my_config.yaml")
- ```
- See https://docs.wandb.ai/models/track/config#file-based-configs.
- Examples:
- Basic usage
- ```
- with wandb.init(config={"epochs": 4}) as run:
- for x in range(run.config.epochs):
- # train
- ```
- Nested values
- ```
- with wandb.init(config={"train": {"epochs": 4}}) as run:
- for x in range(run.config["train"]["epochs"]):
- # train
- ```
- Using absl flags
- ```
- flags.DEFINE_string("model", None, "model to run") # name, default, help
- with wandb.init() as run:
- run.config.update(flags.FLAGS) # adds all absl flags to config
- ```
- Argparse flags
- ```python
- with wandb.init(config={"epochs": 4}) as run:
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "-b",
- "--batch-size",
- type=int,
- default=8,
- metavar="N",
- help="input batch size for training (default: 8)",
- )
- args = parser.parse_args()
- run.config.update(args)
- ```
- Using TensorFlow flags (deprecated in tensorflow v2)
- ```python
- flags = tf.app.flags
- flags.DEFINE_string("data_dir", "/tmp/data")
- flags.DEFINE_integer("batch_size", 128, "Batch size.")
- with wandb.init() as run:
- run.config.update(flags.FLAGS)
- ```
- """
- def __init__(self):
- object.__setattr__(self, "_items", dict())
- object.__setattr__(self, "_locked", dict())
- object.__setattr__(self, "_users", dict())
- object.__setattr__(self, "_users_inv", dict())
- object.__setattr__(self, "_users_cnt", 0)
- object.__setattr__(self, "_callback", None)
- object.__setattr__(self, "_settings", None)
- object.__setattr__(self, "_artifact_callback", None)
- self._load_defaults()
- def _set_callback(self, cb):
- object.__setattr__(self, "_callback", cb)
- def _set_artifact_callback(self, cb):
- object.__setattr__(self, "_artifact_callback", cb)
- def _set_settings(self, settings):
- object.__setattr__(self, "_settings", settings)
- def __repr__(self):
- return str(dict(self))
- def keys(self):
- return [k for k in self._items if not k.startswith("_")]
- def _as_dict(self):
- return self._items
- def as_dict(self):
- # TODO: add telemetry, deprecate, then remove
- return dict(self)
- def __getitem__(self, key):
- return self._items[key]
- def __iter__(self):
- return iter(self._items)
- def _check_locked(self, key, ignore_locked=False) -> bool:
- locked = self._locked.get(key)
- if locked is not None:
- locked_user = self._users_inv[locked]
- if not ignore_locked:
- wandb.termwarn(
- f"Config item '{key}' was locked by '{locked_user}' (ignored update)."
- )
- return True
- return False
- def __setitem__(self, key, val):
- if self._check_locked(key):
- return
- with wandb.sdk.lib.telemetry.context() as tel:
- tel.feature.set_config_item = True
- self._raise_value_error_on_nested_artifact(val, nested=True)
- key, val = self._sanitize(key, val)
- self._items[key] = val
- logger.info("config set %s = %s - %s", key, val, self._callback)
- if self._callback:
- self._callback(key=key, val=val)
- def items(self):
- return [(k, v) for k, v in self._items.items() if not k.startswith("_")]
- __setattr__ = __setitem__
- def __getattr__(self, key):
- try:
- return self.__getitem__(key)
- except KeyError as ke:
- raise AttributeError(
- f"{self.__class__!r} object has no attribute {key!r}"
- ) from ke
- def __contains__(self, key):
- return key in self._items
- def _update(self, d, allow_val_change=None, ignore_locked=None):
- parsed_dict = wandb_helper.parse_config(d)
- locked_keys = set()
- for key in list(parsed_dict):
- if self._check_locked(key, ignore_locked=ignore_locked):
- locked_keys.add(key)
- sanitized = self._sanitize_dict(
- parsed_dict, allow_val_change, ignore_keys=locked_keys
- )
- self._items.update(sanitized)
- return sanitized
- def update(self, d, allow_val_change=None):
- sanitized = self._update(d, allow_val_change)
- if self._callback:
- self._callback(data=sanitized)
- def get(self, *args):
- return self._items.get(*args)
- def persist(self):
- """Call the callback if it's set."""
- if self._callback:
- self._callback(data=self._as_dict())
- def setdefaults(self, d):
- d = wandb_helper.parse_config(d)
- # strip out keys already configured
- d = {k: v for k, v in d.items() if k not in self._items}
- d = self._sanitize_dict(d)
- self._items.update(d)
- if self._callback:
- self._callback(data=d)
- def _get_user_id(self, user) -> int:
- if user not in self._users:
- self._users[user] = self._users_cnt
- self._users_inv[self._users_cnt] = user
- object.__setattr__(self, "_users_cnt", self._users_cnt + 1)
- return self._users[user]
- def update_locked(self, d, user=None, _allow_val_change=None):
- """Shallow-update config with `d` and lock config updates on d's keys."""
- num = self._get_user_id(user)
- for k, v in d.items():
- k, v = self._sanitize(k, v, allow_val_change=_allow_val_change)
- self._locked[k] = num
- self._items[k] = v
- if self._callback:
- self._callback(data=d)
- def merge_locked(self, d, user=None, _allow_val_change=None):
- """Recursively merge-update config with `d` and lock config updates on d's keys."""
- num = self._get_user_id(user)
- callback_d = {}
- for k, v in d.items():
- k, v = self._sanitize(k, v, allow_val_change=_allow_val_change)
- self._locked[k] = num
- if (
- k in self._items
- and isinstance(self._items[k], dict)
- and isinstance(v, dict)
- ):
- self._items[k] = config_util.merge_dicts(self._items[k], v)
- else:
- self._items[k] = v
- callback_d[k] = self._items[k]
- if self._callback:
- self._callback(data=callback_d)
- def _load_defaults(self):
- conf_dict = config_util.dict_from_config_file("config-defaults.yaml")
- if conf_dict is not None:
- self.update(conf_dict)
- def _sanitize_dict(
- self,
- config_dict,
- allow_val_change=None,
- ignore_keys: set | None = None,
- ):
- sanitized = {}
- self._raise_value_error_on_nested_artifact(config_dict)
- for k, v in config_dict.items():
- if ignore_keys and k in ignore_keys:
- continue
- k, v = self._sanitize(k, v, allow_val_change)
- sanitized[k] = v
- return sanitized
- def _sanitize(self, key, val, allow_val_change=None):
- # TODO: enable WBValues in the config in the future
- # refuse all WBValues which is all Media and Histograms
- if isinstance(val, wandb.sdk.data_types.base_types.wb_value.WBValue):
- raise TypeError("WBValue objects cannot be added to the run config")
- # Let jupyter change config freely by default
- if self._settings and self._settings._jupyter and allow_val_change is None:
- allow_val_change = True
- # We always normalize keys by stripping '-'
- key = key.strip("-")
- if _is_artifact_representation(val):
- val = self._artifact_callback(key, val)
- # if the user inserts an artifact into the config
- if not isinstance(val, wandb.Artifact):
- val = json_friendly_val(val)
- if (
- (not allow_val_change)
- and (key in self._items)
- and (val != self._items[key])
- ):
- raise config_util.ConfigError(
- f'Attempted to change value of key "{key}" '
- f"from {self._items[key]} to {val}\n"
- "If you really want to do this, pass"
- " allow_val_change=True to config.update()"
- )
- return key, val
- def _raise_value_error_on_nested_artifact(self, v, nested=False):
- # we can't swap nested artifacts because their root key can be locked by other values
- # best if we don't allow nested artifacts until we can lock nested keys in the config
- if isinstance(v, dict) and check_dict_contains_nested_artifact(v, nested):
- raise ValueError(
- "Instances of wandb.Artifact can only be top level keys in"
- " a run's config"
- )
class ConfigStatic:
    """Read-only snapshot of a config mapping.

    The error messages refer to this object as `run.config_static`. Items are
    readable both as attributes (`cs.lr`) and by key (`cs["lr"]`). Setting or
    deleting attributes, and setting items, raises AttributeError.
    """

    def __init__(self, config):
        # Copy the items into the instance __dict__ so they resolve as plain
        # attributes. object.__setattr__ bypasses our own __setattr__, which
        # always raises.
        object.__setattr__(self, "__dict__", dict(config))

    def __setattr__(self, name, value):
        raise AttributeError("Error: run.config_static is a readonly object")

    def __delattr__(self, name):
        # Without this, `del cs.key` would silently mutate the snapshot.
        raise AttributeError("Error: run.config_static is a readonly object")

    def __setitem__(self, key, val):
        raise AttributeError("Error: run.config_static is a readonly object")

    def keys(self):
        """Return a live view of the snapshot's keys."""
        return self.__dict__.keys()

    def __getitem__(self, key):
        return self.__dict__[key]

    def __contains__(self, key):
        # Without __contains__/__iter__, `in` and `for` fall back to calling
        # __getitem__(0), __getitem__(1), ... and fail with KeyError(0).
        return key in self.__dict__

    def __iter__(self):
        return iter(self.__dict__)

    def __str__(self):
        return str(self.__dict__)
|