wandb_helper.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import inspect
  2. import types
  3. from wandb.errors import UsageError
  4. from .lib import config_util
  5. def parse_config(params, exclude=None, include=None):
  6. if exclude and include:
  7. raise UsageError("Expected at most only one of exclude or include")
  8. if isinstance(params, str):
  9. params = config_util.dict_from_config_file(params, must_exist=True)
  10. params = _to_dict(params)
  11. if include:
  12. params = {key: value for key, value in params.items() if key in include}
  13. if exclude:
  14. params = {key: value for key, value in params.items() if key not in exclude}
  15. return params
  16. def _to_dict(params):
  17. if isinstance(params, dict):
  18. return params
  19. # Handle some cases where params is not a dictionary
  20. # by trying to convert it into a dictionary
  21. meta = inspect.getmodule(params)
  22. if meta:
  23. is_tf_flags_module = (
  24. isinstance(params, types.ModuleType)
  25. and meta.__name__ == "tensorflow.python.platform.flags"
  26. )
  27. if is_tf_flags_module or meta.__name__ == "absl.flags":
  28. params = params.FLAGS
  29. meta = inspect.getmodule(params)
  30. # newer tensorflow flags (post 1.4) uses absl.flags
  31. if meta and meta.__name__ == "absl.flags._flagvalues":
  32. params = {name: params[name].value for name in dir(params)}
  33. elif not hasattr(params, "__dict__"):
  34. raise TypeError("config must be a dict or have a __dict__ attribute.")
  35. elif "__flags" in vars(params):
  36. # for older tensorflow flags (pre 1.4)
  37. if not "__parsed" not in vars(params):
  38. params._parse_flags()
  39. params = vars(params)["__flags"]
  40. else:
  41. # params is a Namespace object (argparse)
  42. # or something else
  43. params = vars(params)
  44. # assume argparse Namespace
  45. return params