_impl.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. from typing import Any
  2. from omegaconf import MISSING, Container, DictConfig, ListConfig, Node, ValueNode
  3. from omegaconf.errors import ConfigTypeError, InterpolationToMissingValueError
  4. from ._utils import _DEFAULT_MARKER_, _get_value
  5. def _resolve_container_value(cfg: Container, key: Any) -> None:
  6. node = cfg._get_child(key)
  7. assert isinstance(node, Node)
  8. if node._is_interpolation():
  9. try:
  10. resolved = node._dereference_node()
  11. except InterpolationToMissingValueError:
  12. node._set_value(MISSING)
  13. else:
  14. if isinstance(resolved, Container):
  15. _resolve(resolved)
  16. if isinstance(resolved, Container) and isinstance(node, ValueNode):
  17. cfg[key] = resolved
  18. else:
  19. node._set_value(_get_value(resolved))
  20. else:
  21. _resolve(node)
  22. def _resolve(cfg: Node) -> Node:
  23. assert isinstance(cfg, Node)
  24. if cfg._is_interpolation():
  25. try:
  26. resolved = cfg._dereference_node()
  27. except InterpolationToMissingValueError:
  28. cfg._set_value(MISSING)
  29. else:
  30. cfg._set_value(resolved._value())
  31. if isinstance(cfg, DictConfig):
  32. for k in cfg.keys():
  33. _resolve_container_value(cfg, k)
  34. elif isinstance(cfg, ListConfig):
  35. for i in range(len(cfg)):
  36. _resolve_container_value(cfg, i)
  37. return cfg
  38. def select_value(
  39. cfg: Container,
  40. key: str,
  41. *,
  42. default: Any = _DEFAULT_MARKER_,
  43. throw_on_resolution_failure: bool = True,
  44. throw_on_missing: bool = False,
  45. absolute_key: bool = False,
  46. ) -> Any:
  47. node = select_node(
  48. cfg=cfg,
  49. key=key,
  50. throw_on_resolution_failure=throw_on_resolution_failure,
  51. throw_on_missing=throw_on_missing,
  52. absolute_key=absolute_key,
  53. )
  54. node_not_found = node is None
  55. if node_not_found or node._is_missing():
  56. if default is not _DEFAULT_MARKER_:
  57. return default
  58. else:
  59. return None
  60. return _get_value(node)
  61. def select_node(
  62. cfg: Container,
  63. key: str,
  64. *,
  65. throw_on_resolution_failure: bool = True,
  66. throw_on_missing: bool = False,
  67. absolute_key: bool = False,
  68. ) -> Any:
  69. try:
  70. # for non relative keys, the interpretation can be:
  71. # 1. relative to cfg
  72. # 2. relative to the config root
  73. # This is controlled by the absolute_key flag. By default, such keys are relative to cfg.
  74. if not absolute_key and not key.startswith("."):
  75. key = f".{key}"
  76. cfg, key = cfg._resolve_key_and_root(key)
  77. _root, _last_key, node = cfg._select_impl(
  78. key,
  79. throw_on_missing=throw_on_missing,
  80. throw_on_resolution_failure=throw_on_resolution_failure,
  81. )
  82. except ConfigTypeError:
  83. return None
  84. return node