dict.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. from typing import Any, List
  2. from omegaconf import AnyNode, Container, DictConfig, ListConfig
  3. from omegaconf._utils import Marker
  4. from omegaconf.basecontainer import BaseContainer
  5. from omegaconf.errors import ConfigKeyError
# Sentinel passed as ``default`` to ``select_value`` by
# ``_get_and_validate_dict_input``. Receiving this exact object back means
# the lookup produced no value, which is then reported as a ConfigKeyError.
_DEFAULT_SELECT_MARKER_: Any = Marker("_DEFAULT_SELECT_MARKER_")
  7. def keys(
  8. key: str,
  9. _parent_: Container,
  10. ) -> ListConfig:
  11. from omegaconf import OmegaConf
  12. assert isinstance(_parent_, BaseContainer)
  13. in_dict = _get_and_validate_dict_input(
  14. key, parent=_parent_, resolver_name="oc.dict.keys"
  15. )
  16. ret = OmegaConf.create(list(in_dict.keys()), parent=_parent_)
  17. assert isinstance(ret, ListConfig)
  18. return ret
  19. def values(key: str, _root_: BaseContainer, _parent_: Container) -> ListConfig:
  20. assert isinstance(_parent_, BaseContainer)
  21. in_dict = _get_and_validate_dict_input(
  22. key, parent=_parent_, resolver_name="oc.dict.values"
  23. )
  24. content = in_dict._content
  25. assert isinstance(content, dict)
  26. ret = ListConfig([])
  27. if key.startswith("."):
  28. key = f".{key}" # extra dot to compensate for extra level of nesting within ret ListConfig
  29. for k in content:
  30. ref_node = AnyNode(f"${{{key}.{k!s}}}")
  31. ret.append(ref_node)
  32. # Finalize result by setting proper type and parent.
  33. element_type: Any = in_dict._metadata.element_type
  34. ret._metadata.element_type = element_type
  35. ret._metadata.ref_type = List[element_type]
  36. ret._set_parent(_parent_)
  37. return ret
  38. def _get_and_validate_dict_input(
  39. key: str,
  40. parent: BaseContainer,
  41. resolver_name: str,
  42. ) -> DictConfig:
  43. from omegaconf._impl import select_value
  44. if not isinstance(key, str):
  45. raise TypeError(
  46. f"`{resolver_name}` requires a string as input, but obtained `{key}` "
  47. f"of type: {type(key).__name__}"
  48. )
  49. in_dict = select_value(
  50. parent,
  51. key,
  52. throw_on_missing=True,
  53. absolute_key=True,
  54. default=_DEFAULT_SELECT_MARKER_,
  55. )
  56. if in_dict is _DEFAULT_SELECT_MARKER_:
  57. raise ConfigKeyError(f"Key not found: '{key}'")
  58. if not isinstance(in_dict, DictConfig):
  59. raise TypeError(
  60. f"`{resolver_name}` cannot be applied to objects of type: "
  61. f"{type(in_dict).__name__}"
  62. )
  63. return in_dict