config.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import dataclasses
  2. from typing import Iterable
  3. def ensure_only_allowed_dataclass_keys_updated(
  4. dataclass: dataclasses.dataclass,
  5. allowed_keys: Iterable[str],
  6. ):
  7. """
  8. Validate dataclass by raising an exception if any key not included in
  9. ``allowed_keys`` differs from the default value.
  10. A ``ValueError`` will also be raised if any of the ``allowed_keys``
  11. is not present in ``dataclass.__dict__``.
  12. Args:
  13. dataclass: Dict or dataclass to check.
  14. allowed_keys: dataclass attribute keys that can have a value different than
  15. the default one.
  16. """
  17. default_data = dataclass.__class__()
  18. allowed_keys = set(allowed_keys)
  19. # TODO: split keys_not_in_dict validation to a separate function.
  20. keys_not_in_dict = [key for key in allowed_keys if key not in default_data.__dict__]
  21. if keys_not_in_dict:
  22. raise ValueError(
  23. f"Key(s) {keys_not_in_dict} are not present in "
  24. f"{dataclass.__class__.__name__}. "
  25. "Remove them from `allowed_keys`. "
  26. f"Valid keys: {list(default_data.__dict__.keys())}"
  27. )
  28. # These keys should not have been updated in the `dataclass` object
  29. prohibited_keys = set(default_data.__dict__) - allowed_keys
  30. bad_keys = [
  31. key
  32. for key in prohibited_keys
  33. if dataclass.__dict__[key] != default_data.__dict__[key]
  34. ]
  35. if bad_keys:
  36. raise ValueError(
  37. f"Key(s) {bad_keys} are not allowed to be updated in the current context. "
  38. "Remove them from the dataclass."
  39. )