omegaconf.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157
  1. """OmegaConf module"""
  2. import copy
  3. import inspect
  4. import io
  5. import os
  6. import pathlib
  7. import sys
  8. import warnings
  9. from collections import defaultdict
  10. from contextlib import contextmanager
  11. from enum import Enum
  12. from textwrap import dedent
  13. from typing import (
  14. IO,
  15. Any,
  16. Callable,
  17. Dict,
  18. Generator,
  19. Iterable,
  20. List,
  21. Optional,
  22. Set,
  23. Tuple,
  24. Type,
  25. Union,
  26. overload,
  27. )
  28. import yaml
  29. from . import DictConfig, DictKeyType, ListConfig
  30. from ._utils import (
  31. _DEFAULT_MARKER_,
  32. _ensure_container,
  33. _get_value,
  34. format_and_raise,
  35. get_dict_key_value_types,
  36. get_list_element_type,
  37. get_omega_conf_dumper,
  38. get_type_of,
  39. is_attr_class,
  40. is_dataclass,
  41. is_dict_annotation,
  42. is_int,
  43. is_list_annotation,
  44. is_primitive_container,
  45. is_primitive_dict,
  46. is_primitive_list,
  47. is_structured_config,
  48. is_tuple_annotation,
  49. is_union_annotation,
  50. nullcontext,
  51. split_key,
  52. type_str,
  53. )
  54. from .base import Box, Container, Node, SCMode, UnionNode
  55. from .basecontainer import BaseContainer
  56. from .errors import (
  57. MissingMandatoryValue,
  58. OmegaConfBaseException,
  59. UnsupportedInterpolationType,
  60. ValidationError,
  61. )
  62. from .nodes import (
  63. AnyNode,
  64. BooleanNode,
  65. BytesNode,
  66. EnumNode,
  67. FloatNode,
  68. IntegerNode,
  69. PathNode,
  70. StringNode,
  71. ValueNode,
  72. )
  73. MISSING: Any = "???"
  74. Resolver = Callable[..., Any]
  75. def II(interpolation: str) -> Any:
  76. """
  77. Equivalent to ``${interpolation}``
  78. :param interpolation:
  79. :return: input ``${node}`` with type Any
  80. """
  81. return "${" + interpolation + "}"
  82. def SI(interpolation: str) -> Any:
  83. """
  84. Use this for String interpolation, for example ``"http://${host}:${port}"``
  85. :param interpolation: interpolation string
  86. :return: input interpolation with type ``Any``
  87. """
  88. return interpolation
  89. def register_default_resolvers() -> None:
  90. from omegaconf.resolvers import oc
  91. OmegaConf.register_new_resolver("oc.create", oc.create)
  92. OmegaConf.register_new_resolver("oc.decode", oc.decode)
  93. OmegaConf.register_new_resolver("oc.deprecated", oc.deprecated)
  94. OmegaConf.register_new_resolver("oc.env", oc.env)
  95. OmegaConf.register_new_resolver("oc.select", oc.select)
  96. OmegaConf.register_new_resolver("oc.dict.keys", oc.dict.keys)
  97. OmegaConf.register_new_resolver("oc.dict.values", oc.dict.values)
  98. class OmegaConf:
  99. """OmegaConf primary class"""
  100. def __init__(self) -> None:
  101. raise NotImplementedError("Use one of the static construction functions")
  102. @staticmethod
  103. def structured(
  104. obj: Any,
  105. parent: Optional[BaseContainer] = None,
  106. flags: Optional[Dict[str, bool]] = None,
  107. ) -> Any:
  108. return OmegaConf.create(obj, parent, flags)
  109. @staticmethod
  110. @overload
  111. def create(
  112. obj: str,
  113. parent: Optional[BaseContainer] = None,
  114. flags: Optional[Dict[str, bool]] = None,
  115. ) -> Union[DictConfig, ListConfig]:
  116. ...
  117. @staticmethod
  118. @overload
  119. def create(
  120. obj: Union[List[Any], Tuple[Any, ...]],
  121. parent: Optional[BaseContainer] = None,
  122. flags: Optional[Dict[str, bool]] = None,
  123. ) -> ListConfig:
  124. ...
  125. @staticmethod
  126. @overload
  127. def create(
  128. obj: DictConfig,
  129. parent: Optional[BaseContainer] = None,
  130. flags: Optional[Dict[str, bool]] = None,
  131. ) -> DictConfig:
  132. ...
  133. @staticmethod
  134. @overload
  135. def create(
  136. obj: ListConfig,
  137. parent: Optional[BaseContainer] = None,
  138. flags: Optional[Dict[str, bool]] = None,
  139. ) -> ListConfig:
  140. ...
  141. @staticmethod
  142. @overload
  143. def create(
  144. obj: Optional[Dict[Any, Any]] = None,
  145. parent: Optional[BaseContainer] = None,
  146. flags: Optional[Dict[str, bool]] = None,
  147. ) -> DictConfig:
  148. ...
  149. @staticmethod
  150. def create( # noqa F811
  151. obj: Any = _DEFAULT_MARKER_,
  152. parent: Optional[BaseContainer] = None,
  153. flags: Optional[Dict[str, bool]] = None,
  154. ) -> Union[DictConfig, ListConfig]:
  155. return OmegaConf._create_impl(
  156. obj=obj,
  157. parent=parent,
  158. flags=flags,
  159. )
  160. @staticmethod
  161. def load(file_: Union[str, pathlib.Path, IO[Any]]) -> Union[DictConfig, ListConfig]:
  162. from ._utils import get_yaml_loader
  163. if isinstance(file_, (str, pathlib.Path)):
  164. with io.open(os.path.abspath(file_), "r", encoding="utf-8") as f:
  165. obj = yaml.load(f, Loader=get_yaml_loader())
  166. elif getattr(file_, "read", None):
  167. obj = yaml.load(file_, Loader=get_yaml_loader())
  168. else:
  169. raise TypeError("Unexpected file type")
  170. if obj is not None and not isinstance(obj, (list, dict, str)):
  171. raise IOError( # pragma: no cover
  172. f"Invalid loaded object type: {type(obj).__name__}"
  173. )
  174. ret: Union[DictConfig, ListConfig]
  175. if obj is None:
  176. ret = OmegaConf.create()
  177. else:
  178. ret = OmegaConf.create(obj)
  179. return ret
  180. @staticmethod
  181. def save(
  182. config: Any, f: Union[str, pathlib.Path, IO[Any]], resolve: bool = False
  183. ) -> None:
  184. """
  185. Save as configuration object to a file
  186. :param config: omegaconf.Config object (DictConfig or ListConfig).
  187. :param f: filename or file object
  188. :param resolve: True to save a resolved config (defaults to False)
  189. """
  190. if is_dataclass(config) or is_attr_class(config):
  191. config = OmegaConf.create(config)
  192. data = OmegaConf.to_yaml(config, resolve=resolve)
  193. if isinstance(f, (str, pathlib.Path)):
  194. with io.open(os.path.abspath(f), "w", encoding="utf-8") as file:
  195. file.write(data)
  196. elif hasattr(f, "write"):
  197. f.write(data)
  198. f.flush()
  199. else:
  200. raise TypeError("Unexpected file type")
  201. @staticmethod
  202. def from_cli(args_list: Optional[List[str]] = None) -> DictConfig:
  203. if args_list is None:
  204. # Skip program name
  205. args_list = sys.argv[1:]
  206. return OmegaConf.from_dotlist(args_list)
  207. @staticmethod
  208. def from_dotlist(dotlist: List[str]) -> DictConfig:
  209. """
  210. Creates config from the content sys.argv or from the specified args list of not None
  211. :param dotlist: A list of dotlist-style strings, e.g. ``["foo.bar=1", "baz=qux"]``.
  212. :return: A ``DictConfig`` object created from the dotlist.
  213. """
  214. conf = OmegaConf.create()
  215. conf.merge_with_dotlist(dotlist)
  216. return conf
  217. @staticmethod
  218. def merge(
  219. *configs: Union[
  220. DictConfig,
  221. ListConfig,
  222. Dict[DictKeyType, Any],
  223. List[Any],
  224. Tuple[Any, ...],
  225. Any,
  226. ],
  227. ) -> Union[ListConfig, DictConfig]:
  228. """
  229. Merge a list of previously created configs into a single one
  230. :param configs: Input configs
  231. :return: the merged config object.
  232. """
  233. assert len(configs) > 0
  234. target = copy.deepcopy(configs[0])
  235. target = _ensure_container(target)
  236. assert isinstance(target, (DictConfig, ListConfig))
  237. with flag_override(target, "readonly", False):
  238. target.merge_with(*configs[1:])
  239. turned_readonly = target._get_flag("readonly") is True
  240. if turned_readonly:
  241. OmegaConf.set_readonly(target, True)
  242. return target
  243. @staticmethod
  244. def unsafe_merge(
  245. *configs: Union[
  246. DictConfig,
  247. ListConfig,
  248. Dict[DictKeyType, Any],
  249. List[Any],
  250. Tuple[Any, ...],
  251. Any,
  252. ],
  253. ) -> Union[ListConfig, DictConfig]:
  254. """
  255. Merge a list of previously created configs into a single one
  256. This is much faster than OmegaConf.merge() as the input configs are not copied.
  257. However, the input configs must not be used after this operation as will become inconsistent.
  258. :param configs: Input configs
  259. :return: the merged config object.
  260. """
  261. assert len(configs) > 0
  262. target = configs[0]
  263. target = _ensure_container(target)
  264. assert isinstance(target, (DictConfig, ListConfig))
  265. with flag_override(
  266. target, ["readonly", "no_deepcopy_set_nodes"], [False, True]
  267. ):
  268. target.merge_with(*configs[1:])
  269. turned_readonly = target._get_flag("readonly") is True
  270. if turned_readonly:
  271. OmegaConf.set_readonly(target, True)
  272. return target
  273. @staticmethod
  274. def register_resolver(name: str, resolver: Resolver) -> None:
  275. warnings.warn(
  276. dedent(
  277. """\
  278. register_resolver() is deprecated.
  279. See https://github.com/omry/omegaconf/issues/426 for migration instructions.
  280. """
  281. ),
  282. stacklevel=2,
  283. )
  284. return OmegaConf.legacy_register_resolver(name, resolver)
  285. # This function will eventually be deprecated and removed.
  286. @staticmethod
  287. def legacy_register_resolver(name: str, resolver: Resolver) -> None:
  288. assert callable(resolver), "resolver must be callable"
  289. # noinspection PyProtectedMember
  290. assert (
  291. name not in BaseContainer._resolvers
  292. ), f"resolver '{name}' is already registered"
  293. def resolver_wrapper(
  294. config: BaseContainer,
  295. parent: BaseContainer,
  296. node: Node,
  297. args: Tuple[Any, ...],
  298. args_str: Tuple[str, ...],
  299. ) -> Any:
  300. cache = OmegaConf.get_cache(config)[name]
  301. # "Un-escape " spaces and commas.
  302. args_unesc = [x.replace(r"\ ", " ").replace(r"\,", ",") for x in args_str]
  303. # Nested interpolations behave in a potentially surprising way with
  304. # legacy resolvers (they remain as strings, e.g., "${foo}"). If any
  305. # input looks like an interpolation we thus raise an exception.
  306. try:
  307. bad_arg = next(i for i in args_unesc if "${" in i)
  308. except StopIteration:
  309. pass
  310. else:
  311. raise ValueError(
  312. f"Resolver '{name}' was called with argument '{bad_arg}' that appears "
  313. f"to be an interpolation. Nested interpolations are not supported for "
  314. f"resolvers registered with `[legacy_]register_resolver()`, please use "
  315. f"`register_new_resolver()` instead (see "
  316. f"https://github.com/omry/omegaconf/issues/426 for migration instructions)."
  317. )
  318. key = args_str
  319. val = cache[key] if key in cache else resolver(*args_unesc)
  320. cache[key] = val
  321. return val
  322. # noinspection PyProtectedMember
  323. BaseContainer._resolvers[name] = resolver_wrapper
  324. @staticmethod
  325. def register_new_resolver(
  326. name: str,
  327. resolver: Resolver,
  328. *,
  329. replace: bool = False,
  330. use_cache: bool = False,
  331. ) -> None:
  332. """
  333. Register a resolver.
  334. :param name: Name of the resolver.
  335. :param resolver: Callable whose arguments are provided in the interpolation,
  336. e.g., with ${foo:x,0,${y.z}} these arguments are respectively "x" (str),
  337. 0 (int) and the value of ``y.z``.
  338. :param replace: If set to ``False`` (default), then a ``ValueError`` is raised if
  339. an existing resolver has already been registered with the same name.
  340. If set to ``True``, then the new resolver replaces the previous one.
  341. NOTE: The cache on existing config objects is not affected, use
  342. ``OmegaConf.clear_cache(cfg)`` to clear it.
  343. :param use_cache: Whether the resolver's outputs should be cached. The cache is
  344. based only on the string literals representing the resolver arguments, e.g.,
  345. ${foo:${bar}} will always return the same value regardless of the value of
  346. ``bar`` if the cache is enabled for ``foo``.
  347. """
  348. if not callable(resolver):
  349. raise TypeError("resolver must be callable")
  350. if not name:
  351. raise ValueError("cannot use an empty resolver name")
  352. if not replace and OmegaConf.has_resolver(name):
  353. raise ValueError(f"resolver '{name}' is already registered")
  354. try:
  355. sig: Optional[inspect.Signature] = inspect.signature(resolver)
  356. except ValueError:
  357. sig = None
  358. def _should_pass(special: str) -> bool:
  359. ret = sig is not None and special in sig.parameters
  360. if ret and use_cache:
  361. raise ValueError(
  362. f"use_cache=True is incompatible with functions that receive the {special}"
  363. )
  364. return ret
  365. pass_parent = _should_pass("_parent_")
  366. pass_node = _should_pass("_node_")
  367. pass_root = _should_pass("_root_")
  368. def resolver_wrapper(
  369. config: BaseContainer,
  370. parent: Container,
  371. node: Node,
  372. args: Tuple[Any, ...],
  373. args_str: Tuple[str, ...],
  374. ) -> Any:
  375. if use_cache:
  376. cache = OmegaConf.get_cache(config)[name]
  377. try:
  378. return cache[args_str]
  379. except KeyError:
  380. pass
  381. # Call resolver.
  382. kwargs: Dict[str, Node] = {}
  383. if pass_parent:
  384. kwargs["_parent_"] = parent
  385. if pass_node:
  386. kwargs["_node_"] = node
  387. if pass_root:
  388. kwargs["_root_"] = config
  389. ret = resolver(*args, **kwargs)
  390. if use_cache:
  391. cache[args_str] = ret
  392. return ret
  393. # noinspection PyProtectedMember
  394. BaseContainer._resolvers[name] = resolver_wrapper
  395. @classmethod
  396. def has_resolver(cls, name: str) -> bool:
  397. return cls._get_resolver(name) is not None
  398. # noinspection PyProtectedMember
  399. @staticmethod
  400. def clear_resolvers() -> None:
  401. """
  402. Clear(remove) all OmegaConf resolvers, then re-register OmegaConf's default resolvers.
  403. """
  404. BaseContainer._resolvers = {}
  405. register_default_resolvers()
  406. @classmethod
  407. def clear_resolver(cls, name: str) -> bool:
  408. """
  409. Clear(remove) any resolver only if it exists.
  410. Returns a bool: True if resolver is removed and False if not removed.
  411. .. warning:
  412. This method can remove deafult resolvers as well.
  413. :param name: Name of the resolver.
  414. :return: A bool (``True`` if resolver is removed, ``False`` if not found before removing).
  415. """
  416. if cls.has_resolver(name):
  417. BaseContainer._resolvers.pop(name)
  418. return True
  419. else:
  420. # return False if resolver does not exist
  421. return False
  422. @staticmethod
  423. def get_cache(conf: BaseContainer) -> Dict[str, Any]:
  424. return conf._metadata.resolver_cache
  425. @staticmethod
  426. def set_cache(conf: BaseContainer, cache: Dict[str, Any]) -> None:
  427. conf._metadata.resolver_cache = copy.deepcopy(cache)
  428. @staticmethod
  429. def clear_cache(conf: BaseContainer) -> None:
  430. OmegaConf.set_cache(conf, defaultdict(dict, {}))
  431. @staticmethod
  432. def copy_cache(from_config: BaseContainer, to_config: BaseContainer) -> None:
  433. OmegaConf.set_cache(to_config, OmegaConf.get_cache(from_config))
  434. @staticmethod
  435. def set_readonly(conf: Node, value: Optional[bool]) -> None:
  436. # noinspection PyProtectedMember
  437. conf._set_flag("readonly", value)
  438. @staticmethod
  439. def is_readonly(conf: Node) -> Optional[bool]:
  440. # noinspection PyProtectedMember
  441. return conf._get_flag("readonly")
  442. @staticmethod
  443. def set_struct(conf: Container, value: Optional[bool]) -> None:
  444. # noinspection PyProtectedMember
  445. conf._set_flag("struct", value)
  446. @staticmethod
  447. def is_struct(conf: Container) -> Optional[bool]:
  448. # noinspection PyProtectedMember
  449. return conf._get_flag("struct")
  450. @staticmethod
  451. def masked_copy(conf: DictConfig, keys: Union[str, List[str]]) -> DictConfig:
  452. """
  453. Create a masked copy of of this config that contains a subset of the keys
  454. :param conf: DictConfig object
  455. :param keys: keys to preserve in the copy
  456. :return: The masked ``DictConfig`` object.
  457. """
  458. from .dictconfig import DictConfig
  459. if not isinstance(conf, DictConfig):
  460. raise ValueError("masked_copy is only supported for DictConfig")
  461. if isinstance(keys, str):
  462. keys = [keys]
  463. content = {key: value for key, value in conf.items_ex(resolve=False, keys=keys)}
  464. return DictConfig(content=content)
  465. @staticmethod
  466. def to_container(
  467. cfg: Any,
  468. *,
  469. resolve: bool = False,
  470. throw_on_missing: bool = False,
  471. enum_to_str: bool = False,
  472. structured_config_mode: SCMode = SCMode.DICT,
  473. ) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]:
  474. """
  475. Resursively converts an OmegaConf config to a primitive container (dict or list).
  476. :param cfg: the config to convert
  477. :param resolve: True to resolve all values
  478. :param throw_on_missing: When True, raise MissingMandatoryValue if any missing values are present.
  479. When False (the default), replace missing values with the string "???" in the output container.
  480. :param enum_to_str: True to convert Enum keys and values to strings
  481. :param structured_config_mode: Specify how Structured Configs (DictConfigs backed by a dataclass) are handled.
  482. - By default (``structured_config_mode=SCMode.DICT``) structured configs are converted to plain dicts.
  483. - If ``structured_config_mode=SCMode.DICT_CONFIG``, structured config nodes will remain as DictConfig.
  484. - If ``structured_config_mode=SCMode.INSTANTIATE``, this function will instantiate structured configs
  485. (DictConfigs backed by a dataclass), by creating an instance of the underlying dataclass.
  486. See also OmegaConf.to_object.
  487. :return: A dict or a list representing this config as a primitive container.
  488. """
  489. if not OmegaConf.is_config(cfg):
  490. raise ValueError(
  491. f"Input cfg is not an OmegaConf config object ({type_str(type(cfg))})"
  492. )
  493. return BaseContainer._to_content(
  494. cfg,
  495. resolve=resolve,
  496. throw_on_missing=throw_on_missing,
  497. enum_to_str=enum_to_str,
  498. structured_config_mode=structured_config_mode,
  499. )
  500. @staticmethod
  501. def to_object(cfg: Any) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]:
  502. """
  503. Resursively converts an OmegaConf config to a primitive container (dict or list).
  504. Any DictConfig objects backed by dataclasses or attrs classes are instantiated
  505. as instances of those backing classes.
  506. This is an alias for OmegaConf.to_container(..., resolve=True, throw_on_missing=True,
  507. structured_config_mode=SCMode.INSTANTIATE)
  508. :param cfg: the config to convert
  509. :return: A dict or a list or dataclass representing this config.
  510. """
  511. return OmegaConf.to_container(
  512. cfg=cfg,
  513. resolve=True,
  514. throw_on_missing=True,
  515. enum_to_str=False,
  516. structured_config_mode=SCMode.INSTANTIATE,
  517. )
  518. @staticmethod
  519. def is_missing(cfg: Any, key: DictKeyType) -> bool:
  520. assert isinstance(cfg, Container)
  521. try:
  522. node = cfg._get_child(key)
  523. if node is None:
  524. return False
  525. assert isinstance(node, Node)
  526. return node._is_missing()
  527. except (UnsupportedInterpolationType, KeyError, AttributeError):
  528. return False
  529. @staticmethod
  530. def is_interpolation(node: Any, key: Optional[Union[int, str]] = None) -> bool:
  531. if key is not None:
  532. assert isinstance(node, Container)
  533. target = node._get_child(key)
  534. else:
  535. target = node
  536. if target is not None:
  537. assert isinstance(target, Node)
  538. return target._is_interpolation()
  539. return False
  540. @staticmethod
  541. def is_list(obj: Any) -> bool:
  542. from . import ListConfig
  543. return isinstance(obj, ListConfig)
  544. @staticmethod
  545. def is_dict(obj: Any) -> bool:
  546. from . import DictConfig
  547. return isinstance(obj, DictConfig)
  548. @staticmethod
  549. def is_config(obj: Any) -> bool:
  550. from . import Container
  551. return isinstance(obj, Container)
  552. @staticmethod
  553. def get_type(obj: Any, key: Optional[str] = None) -> Optional[Type[Any]]:
  554. if key is not None:
  555. c = obj._get_child(key)
  556. else:
  557. c = obj
  558. return OmegaConf._get_obj_type(c)
  559. @staticmethod
  560. def select(
  561. cfg: Container,
  562. key: str,
  563. *,
  564. default: Any = _DEFAULT_MARKER_,
  565. throw_on_resolution_failure: bool = True,
  566. throw_on_missing: bool = False,
  567. ) -> Any:
  568. """
  569. :param cfg: Config node to select from
  570. :param key: Key to select
  571. :param default: Default value to return if key is not found
  572. :param throw_on_resolution_failure: Raise an exception if an interpolation
  573. resolution error occurs, otherwise return None
  574. :param throw_on_missing: Raise an exception if an attempt to select a missing key (with the value '???')
  575. is made, otherwise return None
  576. :return: selected value or None if not found.
  577. """
  578. from ._impl import select_value
  579. try:
  580. return select_value(
  581. cfg=cfg,
  582. key=key,
  583. default=default,
  584. throw_on_resolution_failure=throw_on_resolution_failure,
  585. throw_on_missing=throw_on_missing,
  586. )
  587. except Exception as e:
  588. format_and_raise(node=cfg, key=key, value=None, cause=e, msg=str(e))
  589. @staticmethod
  590. def update(
  591. cfg: Container,
  592. key: str,
  593. value: Any = None,
  594. *,
  595. merge: bool = True,
  596. force_add: bool = False,
  597. ) -> None:
  598. """
  599. Updates a dot separated key sequence to a value
  600. :param cfg: input config to update
  601. :param key: key to update (can be a dot separated path)
  602. :param value: value to set, if value if a list or a dict it will be merged or set
  603. depending on merge_config_values
  604. :param merge: If value is a dict or a list, True (default) to merge
  605. into the destination, False to replace the destination.
  606. :param force_add: insert the entire path regardless of Struct flag or Structured Config nodes.
  607. """
  608. split = split_key(key)
  609. root = cfg
  610. for i in range(len(split) - 1):
  611. k = split[i]
  612. # if next_root is a primitive (string, int etc) replace it with an empty map
  613. next_root, key_ = _select_one(root, k, throw_on_missing=False)
  614. if not isinstance(next_root, Container):
  615. if force_add:
  616. with flag_override(root, "struct", False):
  617. root[key_] = {}
  618. else:
  619. root[key_] = {}
  620. root = root[key_]
  621. last = split[-1]
  622. assert isinstance(
  623. root, Container
  624. ), f"Unexpected type for root: {type(root).__name__}"
  625. last_key: Union[str, int] = last
  626. if isinstance(root, ListConfig):
  627. last_key = int(last)
  628. ctx = flag_override(root, "struct", False) if force_add else nullcontext()
  629. with ctx:
  630. if merge and (OmegaConf.is_config(value) or is_primitive_container(value)):
  631. assert isinstance(root, BaseContainer)
  632. node = root._get_child(last_key)
  633. if OmegaConf.is_config(node):
  634. assert isinstance(node, BaseContainer)
  635. node.merge_with(value)
  636. return
  637. if OmegaConf.is_dict(root):
  638. assert isinstance(last_key, str)
  639. root.__setattr__(last_key, value)
  640. elif OmegaConf.is_list(root):
  641. assert isinstance(last_key, int)
  642. root.__setitem__(last_key, value)
  643. else:
  644. assert False
  645. @staticmethod
  646. def to_yaml(cfg: Any, *, resolve: bool = False, sort_keys: bool = False) -> str:
  647. """
  648. returns a yaml dump of this config object.
  649. :param cfg: Config object, Structured Config type or instance
  650. :param resolve: if True, will return a string with the interpolations resolved, otherwise
  651. interpolations are preserved
  652. :param sort_keys: If True, will print dict keys in sorted order. default False.
  653. :return: A string containing the yaml representation.
  654. """
  655. cfg = _ensure_container(cfg)
  656. container = OmegaConf.to_container(cfg, resolve=resolve, enum_to_str=True)
  657. return yaml.dump( # type: ignore
  658. container,
  659. default_flow_style=False,
  660. allow_unicode=True,
  661. sort_keys=sort_keys,
  662. Dumper=get_omega_conf_dumper(),
  663. )
  664. @staticmethod
  665. def resolve(cfg: Container) -> None:
  666. """
  667. Resolves all interpolations in the given config object in-place.
  668. :param cfg: An OmegaConf container (DictConfig, ListConfig)
  669. Raises a ValueError if the input object is not an OmegaConf container.
  670. """
  671. import omegaconf._impl
  672. if not OmegaConf.is_config(cfg):
  673. # Since this function is mutating the input object in-place, it doesn't make sense to
  674. # auto-convert the input object to an OmegaConf container
  675. raise ValueError(
  676. f"Invalid config type ({type(cfg).__name__}), expected an OmegaConf Container"
  677. )
  678. omegaconf._impl._resolve(cfg)
  679. @staticmethod
  680. def missing_keys(cfg: Any) -> Set[str]:
  681. """
  682. Returns a set of missing keys in a dotlist style.
  683. :param cfg: An ``OmegaConf.Container``,
  684. or a convertible object via ``OmegaConf.create`` (dict, list, ...).
  685. :return: set of strings of the missing keys.
  686. :raises ValueError: On input not representing a config.
  687. """
  688. cfg = _ensure_container(cfg)
  689. missings: Set[str] = set()
  690. def gather(_cfg: Container) -> None:
  691. itr: Iterable[Any]
  692. if isinstance(_cfg, ListConfig):
  693. itr = range(len(_cfg))
  694. else:
  695. itr = _cfg
  696. for key in itr:
  697. if OmegaConf.is_missing(_cfg, key):
  698. missings.add(_cfg._get_full_key(key))
  699. elif OmegaConf.is_config(_cfg[key]):
  700. gather(_cfg[key])
  701. gather(cfg)
  702. return missings
  703. # === private === #
  704. @staticmethod
  705. def _create_impl( # noqa F811
  706. obj: Any = _DEFAULT_MARKER_,
  707. parent: Optional[BaseContainer] = None,
  708. flags: Optional[Dict[str, bool]] = None,
  709. ) -> Union[DictConfig, ListConfig]:
  710. try:
  711. from ._utils import get_yaml_loader
  712. from .dictconfig import DictConfig
  713. from .listconfig import ListConfig
  714. if obj is _DEFAULT_MARKER_:
  715. obj = {}
  716. if isinstance(obj, str):
  717. obj = yaml.load(obj, Loader=get_yaml_loader())
  718. if obj is None:
  719. return OmegaConf.create({}, parent=parent, flags=flags)
  720. elif isinstance(obj, str):
  721. return OmegaConf.create({obj: None}, parent=parent, flags=flags)
  722. else:
  723. assert isinstance(obj, (list, dict))
  724. return OmegaConf.create(obj, parent=parent, flags=flags)
  725. else:
  726. if (
  727. is_primitive_dict(obj)
  728. or OmegaConf.is_dict(obj)
  729. or is_structured_config(obj)
  730. or obj is None
  731. ):
  732. if isinstance(obj, DictConfig):
  733. return DictConfig(
  734. content=obj,
  735. parent=parent,
  736. ref_type=obj._metadata.ref_type,
  737. is_optional=obj._metadata.optional,
  738. key_type=obj._metadata.key_type,
  739. element_type=obj._metadata.element_type,
  740. flags=flags,
  741. )
  742. else:
  743. obj_type = OmegaConf.get_type(obj)
  744. key_type, element_type = get_dict_key_value_types(obj_type)
  745. return DictConfig(
  746. content=obj,
  747. parent=parent,
  748. key_type=key_type,
  749. element_type=element_type,
  750. flags=flags,
  751. )
  752. elif is_primitive_list(obj) or OmegaConf.is_list(obj):
  753. if isinstance(obj, ListConfig):
  754. return ListConfig(
  755. content=obj,
  756. parent=parent,
  757. element_type=obj._metadata.element_type,
  758. ref_type=obj._metadata.ref_type,
  759. is_optional=obj._metadata.optional,
  760. flags=flags,
  761. )
  762. else:
  763. obj_type = OmegaConf.get_type(obj)
  764. element_type = get_list_element_type(obj_type)
  765. return ListConfig(
  766. content=obj,
  767. parent=parent,
  768. element_type=element_type,
  769. ref_type=Any,
  770. is_optional=True,
  771. flags=flags,
  772. )
  773. else:
  774. if isinstance(obj, type):
  775. raise ValidationError(
  776. f"Input class '{obj.__name__}' is not a structured config. "
  777. "did you forget to decorate it as a dataclass?"
  778. )
  779. else:
  780. raise ValidationError(
  781. f"Object of unsupported type: '{type(obj).__name__}'"
  782. )
  783. except OmegaConfBaseException as e:
  784. format_and_raise(node=None, key=None, value=None, msg=str(e), cause=e)
  785. assert False
  786. @staticmethod
  787. def _get_obj_type(c: Any) -> Optional[Type[Any]]:
  788. if is_structured_config(c):
  789. return get_type_of(c)
  790. elif c is None:
  791. return None
  792. elif isinstance(c, DictConfig):
  793. if c._is_none():
  794. return None
  795. elif c._is_missing():
  796. return None
  797. else:
  798. if is_structured_config(c._metadata.object_type):
  799. return c._metadata.object_type
  800. else:
  801. return dict
  802. elif isinstance(c, ListConfig):
  803. return list
  804. elif isinstance(c, ValueNode):
  805. return type(c._value())
  806. elif isinstance(c, UnionNode):
  807. return type(_get_value(c))
  808. elif isinstance(c, dict):
  809. return dict
  810. elif isinstance(c, (list, tuple)):
  811. return list
  812. else:
  813. return get_type_of(c)
  814. @staticmethod
  815. def _get_resolver(
  816. name: str,
  817. ) -> Optional[
  818. Callable[
  819. [Container, Container, Node, Tuple[Any, ...], Tuple[str, ...]],
  820. Any,
  821. ]
  822. ]:
  823. # noinspection PyProtectedMember
  824. return (
  825. BaseContainer._resolvers[name] if name in BaseContainer._resolvers else None
  826. )
# Register all default resolvers. This is a module-level statement, so it runs
# once, when this module is first imported.
register_default_resolvers()
  829. @contextmanager
  830. def flag_override(
  831. config: Node,
  832. names: Union[List[str], str],
  833. values: Union[List[Optional[bool]], Optional[bool]],
  834. ) -> Generator[Node, None, None]:
  835. if isinstance(names, str):
  836. names = [names]
  837. if values is None or isinstance(values, bool):
  838. values = [values]
  839. prev_states = [config._get_node_flag(name) for name in names]
  840. try:
  841. config._set_flag(names, values)
  842. yield config
  843. finally:
  844. config._set_flag(names, prev_states)
  845. @contextmanager
  846. def read_write(config: Node) -> Generator[Node, None, None]:
  847. prev_state = config._get_node_flag("readonly")
  848. try:
  849. OmegaConf.set_readonly(config, False)
  850. yield config
  851. finally:
  852. OmegaConf.set_readonly(config, prev_state)
  853. @contextmanager
  854. def open_dict(config: Container) -> Generator[Container, None, None]:
  855. prev_state = config._get_node_flag("struct")
  856. try:
  857. OmegaConf.set_struct(config, False)
  858. yield config
  859. finally:
  860. OmegaConf.set_struct(config, prev_state)
  861. # === private === #
  862. def _node_wrap(
  863. parent: Optional[Box],
  864. is_optional: bool,
  865. value: Any,
  866. key: Any,
  867. ref_type: Any = Any,
  868. ) -> Node:
  869. node: Node
  870. if is_dict_annotation(ref_type) or (is_primitive_dict(value) and ref_type is Any):
  871. key_type, element_type = get_dict_key_value_types(ref_type)
  872. node = DictConfig(
  873. content=value,
  874. key=key,
  875. parent=parent,
  876. ref_type=ref_type,
  877. is_optional=is_optional,
  878. key_type=key_type,
  879. element_type=element_type,
  880. )
  881. elif (is_list_annotation(ref_type) or is_tuple_annotation(ref_type)) or (
  882. type(value) in (list, tuple) and ref_type is Any
  883. ):
  884. element_type = get_list_element_type(ref_type)
  885. node = ListConfig(
  886. content=value,
  887. key=key,
  888. parent=parent,
  889. is_optional=is_optional,
  890. element_type=element_type,
  891. ref_type=ref_type,
  892. )
  893. elif is_structured_config(ref_type) or is_structured_config(value):
  894. key_type, element_type = get_dict_key_value_types(value)
  895. node = DictConfig(
  896. ref_type=ref_type,
  897. is_optional=is_optional,
  898. content=value,
  899. key=key,
  900. parent=parent,
  901. key_type=key_type,
  902. element_type=element_type,
  903. )
  904. elif is_union_annotation(ref_type):
  905. node = UnionNode(
  906. content=value,
  907. ref_type=ref_type,
  908. is_optional=is_optional,
  909. key=key,
  910. parent=parent,
  911. )
  912. elif ref_type == Any or ref_type is None:
  913. node = AnyNode(value=value, key=key, parent=parent)
  914. elif isinstance(ref_type, type) and issubclass(ref_type, Enum):
  915. node = EnumNode(
  916. enum_type=ref_type,
  917. value=value,
  918. key=key,
  919. parent=parent,
  920. is_optional=is_optional,
  921. )
  922. elif ref_type == int:
  923. node = IntegerNode(value=value, key=key, parent=parent, is_optional=is_optional)
  924. elif ref_type == float:
  925. node = FloatNode(value=value, key=key, parent=parent, is_optional=is_optional)
  926. elif ref_type == bool:
  927. node = BooleanNode(value=value, key=key, parent=parent, is_optional=is_optional)
  928. elif ref_type == str:
  929. node = StringNode(value=value, key=key, parent=parent, is_optional=is_optional)
  930. elif ref_type == bytes:
  931. node = BytesNode(value=value, key=key, parent=parent, is_optional=is_optional)
  932. elif ref_type == pathlib.Path:
  933. node = PathNode(value=value, key=key, parent=parent, is_optional=is_optional)
  934. else:
  935. if parent is not None and parent._get_flag("allow_objects") is True:
  936. if type(value) in (list, tuple):
  937. node = ListConfig(
  938. content=value,
  939. key=key,
  940. parent=parent,
  941. ref_type=ref_type,
  942. is_optional=is_optional,
  943. )
  944. elif is_primitive_dict(value):
  945. node = DictConfig(
  946. content=value,
  947. key=key,
  948. parent=parent,
  949. ref_type=ref_type,
  950. is_optional=is_optional,
  951. )
  952. else:
  953. node = AnyNode(value=value, key=key, parent=parent)
  954. else:
  955. raise ValidationError(f"Unexpected type annotation: {type_str(ref_type)}")
  956. return node
  957. def _maybe_wrap(
  958. ref_type: Any,
  959. key: Any,
  960. value: Any,
  961. is_optional: bool,
  962. parent: Optional[BaseContainer],
  963. ) -> Node:
  964. # if already a node, update key and parent and return as is.
  965. # NOTE: that this mutate the input node!
  966. if isinstance(value, Node):
  967. value._set_key(key)
  968. value._set_parent(parent)
  969. return value
  970. else:
  971. return _node_wrap(
  972. ref_type=ref_type,
  973. parent=parent,
  974. is_optional=is_optional,
  975. value=value,
  976. key=key,
  977. )
  978. def _select_one(
  979. c: Container, key: str, throw_on_missing: bool, throw_on_type_error: bool = True
  980. ) -> Tuple[Optional[Node], Union[str, int]]:
  981. from .dictconfig import DictConfig
  982. from .listconfig import ListConfig
  983. ret_key: Union[str, int] = key
  984. assert isinstance(c, Container), f"Unexpected type: {c}"
  985. if c._is_none():
  986. return None, ret_key
  987. if isinstance(c, DictConfig):
  988. assert isinstance(ret_key, str)
  989. val = c._get_child(ret_key, validate_access=False)
  990. elif isinstance(c, ListConfig):
  991. assert isinstance(ret_key, str)
  992. if not is_int(ret_key):
  993. if throw_on_type_error:
  994. raise TypeError(
  995. f"Index '{ret_key}' ({type(ret_key).__name__}) is not an int"
  996. )
  997. else:
  998. val = None
  999. else:
  1000. ret_key = int(ret_key)
  1001. if ret_key < 0 or ret_key + 1 > len(c):
  1002. val = None
  1003. else:
  1004. val = c._get_child(ret_key)
  1005. else:
  1006. assert False
  1007. if val is not None:
  1008. assert isinstance(val, Node)
  1009. if val._is_missing():
  1010. if throw_on_missing:
  1011. raise MissingMandatoryValue(
  1012. f"Missing mandatory value: {c._get_full_key(ret_key)}"
  1013. )
  1014. else:
  1015. return val, ret_key
  1016. assert val is None or isinstance(val, Node)
  1017. return val, ret_key