basecontainer.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916
  1. import copy
  2. import sys
  3. from abc import ABC, abstractmethod
  4. from enum import Enum
  5. from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, Union
  6. import yaml
  7. from ._utils import (
  8. _DEFAULT_MARKER_,
  9. ValueKind,
  10. _ensure_container,
  11. _get_value,
  12. _is_interpolation,
  13. _is_missing_value,
  14. _is_none,
  15. _is_special,
  16. _resolve_optional,
  17. get_structured_config_data,
  18. get_type_hint,
  19. get_value_kind,
  20. get_yaml_loader,
  21. is_container_annotation,
  22. is_dict_annotation,
  23. is_list_annotation,
  24. is_primitive_dict,
  25. is_primitive_type_annotation,
  26. is_structured_config,
  27. is_tuple_annotation,
  28. is_union_annotation,
  29. )
  30. from .base import (
  31. Box,
  32. Container,
  33. ContainerMetadata,
  34. DictKeyType,
  35. Node,
  36. SCMode,
  37. UnionNode,
  38. )
  39. from .errors import (
  40. ConfigCycleDetectedException,
  41. ConfigTypeError,
  42. InterpolationResolutionError,
  43. KeyValidationError,
  44. MissingMandatoryValue,
  45. OmegaConfBaseException,
  46. ReadonlyConfigError,
  47. ValidationError,
  48. )
  49. if TYPE_CHECKING:
  50. from .dictconfig import DictConfig # pragma: no cover
  51. class BaseContainer(Container, ABC):
  52. _resolvers: ClassVar[Dict[str, Any]] = {}
  53. def __init__(self, parent: Optional[Box], metadata: ContainerMetadata):
  54. if not (parent is None or isinstance(parent, Box)):
  55. raise ConfigTypeError("Parent type is not omegaconf.Box")
  56. super().__init__(parent=parent, metadata=metadata)
  57. def _get_child(
  58. self,
  59. key: Any,
  60. validate_access: bool = True,
  61. validate_key: bool = True,
  62. throw_on_missing_value: bool = False,
  63. throw_on_missing_key: bool = False,
  64. ) -> Union[Optional[Node], List[Optional[Node]]]:
  65. """Like _get_node, passing through to the nearest concrete Node."""
  66. child = self._get_node(
  67. key=key,
  68. validate_access=validate_access,
  69. validate_key=validate_key,
  70. throw_on_missing_value=throw_on_missing_value,
  71. throw_on_missing_key=throw_on_missing_key,
  72. )
  73. if isinstance(child, UnionNode) and not _is_special(child):
  74. value = child._value()
  75. assert isinstance(value, Node) and not isinstance(value, UnionNode)
  76. child = value
  77. return child
  78. def _resolve_with_default(
  79. self,
  80. key: Union[DictKeyType, int],
  81. value: Node,
  82. default_value: Any = _DEFAULT_MARKER_,
  83. ) -> Any:
  84. """returns the value with the specified key, like obj.key and obj['key']"""
  85. if _is_missing_value(value):
  86. if default_value is not _DEFAULT_MARKER_:
  87. return default_value
  88. raise MissingMandatoryValue("Missing mandatory value: $FULL_KEY")
  89. resolved_node = self._maybe_resolve_interpolation(
  90. parent=self,
  91. key=key,
  92. value=value,
  93. throw_on_resolution_failure=True,
  94. )
  95. return _get_value(resolved_node)
  96. def __str__(self) -> str:
  97. return self.__repr__()
  98. def __repr__(self) -> str:
  99. if self.__dict__["_content"] is None:
  100. return "None"
  101. elif self._is_interpolation() or self._is_missing():
  102. v = self.__dict__["_content"]
  103. return f"'{v}'"
  104. else:
  105. return self.__dict__["_content"].__repr__() # type: ignore
  106. # Support pickle
  107. def __getstate__(self) -> Dict[str, Any]:
  108. dict_copy = copy.copy(self.__dict__)
  109. # no need to serialize the flags cache, it can be re-constructed later
  110. dict_copy.pop("_flags_cache", None)
  111. dict_copy["_metadata"] = copy.copy(dict_copy["_metadata"])
  112. ref_type = self._metadata.ref_type
  113. if is_container_annotation(ref_type):
  114. if is_dict_annotation(ref_type):
  115. dict_copy["_metadata"].ref_type = Dict
  116. elif is_list_annotation(ref_type):
  117. dict_copy["_metadata"].ref_type = List
  118. else:
  119. assert False
  120. if sys.version_info < (3, 7): # pragma: no cover
  121. element_type = self._metadata.element_type
  122. if is_union_annotation(element_type):
  123. raise OmegaConfBaseException(
  124. "Serializing structured configs with `Union` element type requires python >= 3.7"
  125. )
  126. return dict_copy
  127. # Support pickle
  128. def __setstate__(self, d: Dict[str, Any]) -> None:
  129. from omegaconf import DictConfig
  130. from omegaconf._utils import is_generic_dict, is_generic_list
  131. if isinstance(self, DictConfig):
  132. key_type = d["_metadata"].key_type
  133. # backward compatibility to load OmegaConf 2.0 configs
  134. if key_type is None:
  135. key_type = Any
  136. d["_metadata"].key_type = key_type
  137. element_type = d["_metadata"].element_type
  138. # backward compatibility to load OmegaConf 2.0 configs
  139. if element_type is None:
  140. element_type = Any
  141. d["_metadata"].element_type = element_type
  142. ref_type = d["_metadata"].ref_type
  143. if is_container_annotation(ref_type):
  144. if is_generic_dict(ref_type):
  145. d["_metadata"].ref_type = Dict[key_type, element_type] # type: ignore
  146. elif is_generic_list(ref_type):
  147. d["_metadata"].ref_type = List[element_type] # type: ignore
  148. else:
  149. assert False
  150. d["_flags_cache"] = None
  151. self.__dict__.update(d)
  152. @abstractmethod
  153. def __delitem__(self, key: Any) -> None:
  154. ...
  155. def __len__(self) -> int:
  156. if self._is_none() or self._is_missing() or self._is_interpolation():
  157. return 0
  158. content = self.__dict__["_content"]
  159. return len(content)
  160. def merge_with_cli(self) -> None:
  161. args_list = sys.argv[1:]
  162. self.merge_with_dotlist(args_list)
  163. def merge_with_dotlist(self, dotlist: List[str]) -> None:
  164. from omegaconf import OmegaConf
  165. def fail() -> None:
  166. raise ValueError("Input list must be a list or a tuple of strings")
  167. if not isinstance(dotlist, (list, tuple)):
  168. fail()
  169. for arg in dotlist:
  170. if not isinstance(arg, str):
  171. fail()
  172. idx = arg.find("=")
  173. if idx == -1:
  174. key = arg
  175. value = None
  176. else:
  177. key = arg[0:idx]
  178. value = arg[idx + 1 :]
  179. value = yaml.load(value, Loader=get_yaml_loader())
  180. OmegaConf.update(self, key, value)
  181. def is_empty(self) -> bool:
  182. """return true if config is empty"""
  183. return len(self.__dict__["_content"]) == 0
  184. @staticmethod
  185. def _to_content(
  186. conf: Container,
  187. resolve: bool,
  188. throw_on_missing: bool,
  189. enum_to_str: bool = False,
  190. structured_config_mode: SCMode = SCMode.DICT,
  191. ) -> Union[None, Any, str, Dict[DictKeyType, Any], List[Any]]:
  192. from omegaconf import MISSING, DictConfig, ListConfig
  193. def convert(val: Node) -> Any:
  194. value = val._value()
  195. if enum_to_str and isinstance(value, Enum):
  196. value = f"{value.name}"
  197. return value
  198. def get_node_value(key: Union[DictKeyType, int]) -> Any:
  199. try:
  200. node = conf._get_child(key, throw_on_missing_value=throw_on_missing)
  201. except MissingMandatoryValue as e:
  202. conf._format_and_raise(key=key, value=None, cause=e)
  203. assert isinstance(node, Node)
  204. if resolve:
  205. try:
  206. node = node._dereference_node()
  207. except InterpolationResolutionError as e:
  208. conf._format_and_raise(key=key, value=None, cause=e)
  209. if isinstance(node, Container):
  210. value = BaseContainer._to_content(
  211. node,
  212. resolve=resolve,
  213. throw_on_missing=throw_on_missing,
  214. enum_to_str=enum_to_str,
  215. structured_config_mode=structured_config_mode,
  216. )
  217. else:
  218. value = convert(node)
  219. return value
  220. if conf._is_none():
  221. return None
  222. elif conf._is_missing():
  223. if throw_on_missing:
  224. conf._format_and_raise(
  225. key=None,
  226. value=None,
  227. cause=MissingMandatoryValue("Missing mandatory value"),
  228. )
  229. else:
  230. return MISSING
  231. elif not resolve and conf._is_interpolation():
  232. inter = conf._value()
  233. assert isinstance(inter, str)
  234. return inter
  235. if resolve:
  236. _conf = conf._dereference_node()
  237. assert isinstance(_conf, Container)
  238. conf = _conf
  239. if isinstance(conf, DictConfig):
  240. if (
  241. conf._metadata.object_type not in (dict, None)
  242. and structured_config_mode == SCMode.DICT_CONFIG
  243. ):
  244. return conf
  245. if structured_config_mode == SCMode.INSTANTIATE and is_structured_config(
  246. conf._metadata.object_type
  247. ):
  248. return conf._to_object()
  249. retdict: Dict[DictKeyType, Any] = {}
  250. for key in conf.keys():
  251. value = get_node_value(key)
  252. if enum_to_str and isinstance(key, Enum):
  253. key = f"{key.name}"
  254. retdict[key] = value
  255. return retdict
  256. elif isinstance(conf, ListConfig):
  257. retlist: List[Any] = []
  258. for index in range(len(conf)):
  259. item = get_node_value(index)
  260. retlist.append(item)
  261. return retlist
  262. assert False
  263. @staticmethod
  264. def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None:
  265. """merge src into dest and return a new copy, does not modified input"""
  266. from omegaconf import AnyNode, DictConfig, ValueNode
  267. assert isinstance(dest, DictConfig)
  268. assert isinstance(src, DictConfig)
  269. src_type = src._metadata.object_type
  270. src_ref_type = get_type_hint(src)
  271. assert src_ref_type is not None
  272. # If source DictConfig is:
  273. # - None => set the destination DictConfig to None
  274. # - an interpolation => set the destination DictConfig to be the same interpolation
  275. if src._is_none() or src._is_interpolation():
  276. dest._set_value(src._value())
  277. _update_types(node=dest, ref_type=src_ref_type, object_type=src_type)
  278. return
  279. dest._validate_merge(value=src)
  280. def expand(node: Container) -> None:
  281. rt = node._metadata.ref_type
  282. val: Any
  283. if rt is not Any:
  284. if is_dict_annotation(rt):
  285. val = {}
  286. elif is_list_annotation(rt) or is_tuple_annotation(rt):
  287. val = []
  288. else:
  289. val = rt
  290. elif isinstance(node, DictConfig):
  291. val = {}
  292. else:
  293. assert False
  294. node._set_value(val)
  295. if (
  296. src._is_missing()
  297. and not dest._is_missing()
  298. and is_structured_config(src_ref_type)
  299. ):
  300. # Replace `src` with a prototype of its corresponding structured config
  301. # whose fields are all missing (to avoid overwriting fields in `dest`).
  302. assert src_type is None # src missing, so src's object_type should be None
  303. src_type = src_ref_type
  304. src = _create_structured_with_missing_fields(
  305. ref_type=src_ref_type, object_type=src_type
  306. )
  307. if (dest._is_interpolation() or dest._is_missing()) and not src._is_missing():
  308. expand(dest)
  309. src_items = list(src) if not src._is_missing() else []
  310. for key in src_items:
  311. src_node = src._get_node(key, validate_access=False)
  312. dest_node = dest._get_node(key, validate_access=False)
  313. assert isinstance(src_node, Node)
  314. assert dest_node is None or isinstance(dest_node, Node)
  315. src_value = _get_value(src_node)
  316. src_vk = get_value_kind(src_node)
  317. src_node_missing = src_vk is ValueKind.MANDATORY_MISSING
  318. if isinstance(dest_node, DictConfig):
  319. dest_node._validate_merge(value=src_node)
  320. if (
  321. isinstance(dest_node, Container)
  322. and dest_node._is_none()
  323. and not src_node_missing
  324. and not _is_none(src_node, resolve=True)
  325. ):
  326. expand(dest_node)
  327. if dest_node is not None and dest_node._is_interpolation():
  328. target_node = dest_node._maybe_dereference_node()
  329. if isinstance(target_node, Container):
  330. dest[key] = target_node
  331. dest_node = dest._get_node(key)
  332. is_optional, et = _resolve_optional(dest._metadata.element_type)
  333. if dest_node is None and is_structured_config(et) and not src_node_missing:
  334. # merging into a new node. Use element_type as a base
  335. dest[key] = DictConfig(
  336. et, parent=dest, ref_type=et, is_optional=is_optional
  337. )
  338. dest_node = dest._get_node(key)
  339. if dest_node is not None:
  340. if isinstance(dest_node, BaseContainer):
  341. if isinstance(src_node, BaseContainer):
  342. dest_node._merge_with(src_node)
  343. elif not src_node_missing:
  344. dest.__setitem__(key, src_node)
  345. else:
  346. if isinstance(src_node, BaseContainer):
  347. dest.__setitem__(key, src_node)
  348. else:
  349. assert isinstance(dest_node, (ValueNode, UnionNode))
  350. assert isinstance(src_node, (ValueNode, UnionNode))
  351. try:
  352. if isinstance(dest_node, AnyNode):
  353. if src_node_missing:
  354. node = copy.copy(src_node)
  355. # if src node is missing, use the value from the dest_node,
  356. # but validate it against the type of the src node before assigment
  357. node._set_value(dest_node._value())
  358. else:
  359. node = src_node
  360. dest.__setitem__(key, node)
  361. else:
  362. if not src_node_missing:
  363. dest_node._set_value(src_value)
  364. except (ValidationError, ReadonlyConfigError) as e:
  365. dest._format_and_raise(key=key, value=src_value, cause=e)
  366. else:
  367. from omegaconf import open_dict
  368. if is_structured_config(src_type):
  369. # verified to be compatible above in _validate_merge
  370. with open_dict(dest):
  371. dest[key] = src._get_node(key)
  372. else:
  373. dest[key] = src._get_node(key)
  374. _update_types(node=dest, ref_type=src_ref_type, object_type=src_type)
  375. # explicit flags on the source config are replacing the flag values in the destination
  376. flags = src._metadata.flags
  377. assert flags is not None
  378. for flag, value in flags.items():
  379. if value is not None:
  380. dest._set_flag(flag, value)
  381. @staticmethod
  382. def _list_merge(dest: Any, src: Any) -> None:
  383. from omegaconf import DictConfig, ListConfig, OmegaConf
  384. assert isinstance(dest, ListConfig)
  385. assert isinstance(src, ListConfig)
  386. if src._is_none():
  387. dest._set_value(None)
  388. elif src._is_missing():
  389. # do not change dest if src is MISSING.
  390. if dest._metadata.element_type is Any:
  391. dest._metadata.element_type = src._metadata.element_type
  392. elif src._is_interpolation():
  393. dest._set_value(src._value())
  394. else:
  395. temp_target = ListConfig(content=[], parent=dest._get_parent())
  396. temp_target.__dict__["_metadata"] = copy.deepcopy(
  397. dest.__dict__["_metadata"]
  398. )
  399. is_optional, et = _resolve_optional(dest._metadata.element_type)
  400. if is_structured_config(et):
  401. prototype = DictConfig(et, ref_type=et, is_optional=is_optional)
  402. for item in src._iter_ex(resolve=False):
  403. if isinstance(item, DictConfig):
  404. item = OmegaConf.merge(prototype, item)
  405. temp_target.append(item)
  406. else:
  407. for item in src._iter_ex(resolve=False):
  408. temp_target.append(item)
  409. dest.__dict__["_content"] = temp_target.__dict__["_content"]
  410. # explicit flags on the source config are replacing the flag values in the destination
  411. flags = src._metadata.flags
  412. assert flags is not None
  413. for flag, value in flags.items():
  414. if value is not None:
  415. dest._set_flag(flag, value)
  416. def merge_with(
  417. self,
  418. *others: Union[
  419. "BaseContainer", Dict[str, Any], List[Any], Tuple[Any, ...], Any
  420. ],
  421. ) -> None:
  422. try:
  423. self._merge_with(*others)
  424. except Exception as e:
  425. self._format_and_raise(key=None, value=None, cause=e)
  426. def _merge_with(
  427. self,
  428. *others: Union[
  429. "BaseContainer", Dict[str, Any], List[Any], Tuple[Any, ...], Any
  430. ],
  431. ) -> None:
  432. from .dictconfig import DictConfig
  433. from .listconfig import ListConfig
  434. """merge a list of other Config objects into this one, overriding as needed"""
  435. for other in others:
  436. if other is None:
  437. raise ValueError("Cannot merge with a None config")
  438. my_flags = {}
  439. if self._get_flag("allow_objects") is True:
  440. my_flags = {"allow_objects": True}
  441. other = _ensure_container(other, flags=my_flags)
  442. if isinstance(self, DictConfig) and isinstance(other, DictConfig):
  443. BaseContainer._map_merge(self, other)
  444. elif isinstance(self, ListConfig) and isinstance(other, ListConfig):
  445. BaseContainer._list_merge(self, other)
  446. else:
  447. raise TypeError("Cannot merge DictConfig with ListConfig")
  448. # recursively correct the parent hierarchy after the merge
  449. self._re_parent()
  450. # noinspection PyProtectedMember
  451. def _set_item_impl(self, key: Any, value: Any) -> None:
  452. """
  453. Changes the value of the node key with the desired value. If the node key doesn't
  454. exist it creates a new one.
  455. """
  456. from .nodes import AnyNode, ValueNode
  457. if isinstance(value, Node):
  458. do_deepcopy = not self._get_flag("no_deepcopy_set_nodes")
  459. if not do_deepcopy and isinstance(value, Box):
  460. # if value is from the same config, perform a deepcopy no matter what.
  461. if self._get_root() is value._get_root():
  462. do_deepcopy = True
  463. if do_deepcopy:
  464. value = copy.deepcopy(value)
  465. value._set_parent(None)
  466. try:
  467. old = value._key()
  468. value._set_key(key)
  469. self._validate_set(key, value)
  470. finally:
  471. value._set_key(old)
  472. else:
  473. self._validate_set(key, value)
  474. if self._get_flag("readonly"):
  475. raise ReadonlyConfigError("Cannot change read-only config container")
  476. input_is_node = isinstance(value, Node)
  477. target_node_ref = self._get_node(key)
  478. assert target_node_ref is None or isinstance(target_node_ref, Node)
  479. input_is_typed_vnode = isinstance(value, ValueNode) and not isinstance(
  480. value, AnyNode
  481. )
  482. def get_target_type_hint(val: Any) -> Any:
  483. if not is_structured_config(val):
  484. type_hint = self._metadata.element_type
  485. else:
  486. target = self._get_node(key)
  487. if target is None:
  488. type_hint = self._metadata.element_type
  489. else:
  490. assert isinstance(target, Node)
  491. type_hint = target._metadata.type_hint
  492. return type_hint
  493. target_type_hint = get_target_type_hint(value)
  494. _, target_ref_type = _resolve_optional(target_type_hint)
  495. def assign(value_key: Any, val: Node) -> None:
  496. assert val._get_parent() is None
  497. v = val
  498. v._set_parent(self)
  499. v._set_key(value_key)
  500. _deep_update_type_hint(node=v, type_hint=self._metadata.element_type)
  501. self.__dict__["_content"][value_key] = v
  502. if input_is_typed_vnode and not is_union_annotation(target_ref_type):
  503. assign(key, value)
  504. else:
  505. # input is not a ValueNode, can be primitive or box
  506. special_value = _is_special(value)
  507. # We use the `Node._set_value` method if the target node exists and:
  508. # 1. the target has an explicit ref_type, or
  509. # 2. the target is an AnyNode and the input is a primitive type.
  510. should_set_value = target_node_ref is not None and (
  511. target_node_ref._has_ref_type()
  512. or (
  513. isinstance(target_node_ref, AnyNode)
  514. and is_primitive_type_annotation(value)
  515. )
  516. )
  517. if should_set_value:
  518. if special_value and isinstance(value, Node):
  519. value = value._value()
  520. self.__dict__["_content"][key]._set_value(value)
  521. elif input_is_node:
  522. if (
  523. special_value
  524. and (
  525. is_container_annotation(target_ref_type)
  526. or is_structured_config(target_ref_type)
  527. )
  528. or is_primitive_type_annotation(target_ref_type)
  529. or is_union_annotation(target_ref_type)
  530. ):
  531. value = _get_value(value)
  532. self._wrap_value_and_set(key, value, target_type_hint)
  533. else:
  534. assign(key, value)
  535. else:
  536. self._wrap_value_and_set(key, value, target_type_hint)
  537. def _wrap_value_and_set(self, key: Any, val: Any, type_hint: Any) -> None:
  538. from omegaconf.omegaconf import _maybe_wrap
  539. is_optional, ref_type = _resolve_optional(type_hint)
  540. try:
  541. wrapped = _maybe_wrap(
  542. ref_type=ref_type,
  543. key=key,
  544. value=val,
  545. is_optional=is_optional,
  546. parent=self,
  547. )
  548. except ValidationError as e:
  549. self._format_and_raise(key=key, value=val, cause=e)
  550. self.__dict__["_content"][key] = wrapped
  551. @staticmethod
  552. def _item_eq(
  553. c1: Container,
  554. k1: Union[DictKeyType, int],
  555. c2: Container,
  556. k2: Union[DictKeyType, int],
  557. ) -> bool:
  558. v1 = c1._get_child(k1)
  559. v2 = c2._get_child(k2)
  560. assert v1 is not None and v2 is not None
  561. assert isinstance(v1, Node)
  562. assert isinstance(v2, Node)
  563. if v1._is_none() and v2._is_none():
  564. return True
  565. if v1._is_missing() and v2._is_missing():
  566. return True
  567. v1_inter = v1._is_interpolation()
  568. v2_inter = v2._is_interpolation()
  569. dv1: Optional[Node] = v1
  570. dv2: Optional[Node] = v2
  571. if v1_inter:
  572. dv1 = v1._maybe_dereference_node()
  573. if v2_inter:
  574. dv2 = v2._maybe_dereference_node()
  575. if v1_inter and v2_inter:
  576. if dv1 is None or dv2 is None:
  577. return v1 == v2
  578. else:
  579. # both are not none, if both are containers compare as container
  580. if isinstance(dv1, Container) and isinstance(dv2, Container):
  581. if dv1 != dv2:
  582. return False
  583. dv1 = _get_value(dv1)
  584. dv2 = _get_value(dv2)
  585. return dv1 == dv2
  586. elif not v1_inter and not v2_inter:
  587. v1 = _get_value(v1)
  588. v2 = _get_value(v2)
  589. ret = v1 == v2
  590. assert isinstance(ret, bool)
  591. return ret
  592. else:
  593. dv1 = _get_value(dv1)
  594. dv2 = _get_value(dv2)
  595. ret = dv1 == dv2
  596. assert isinstance(ret, bool)
  597. return ret
  598. def _is_optional(self) -> bool:
  599. return self.__dict__["_metadata"].optional is True
  600. def _is_interpolation(self) -> bool:
  601. return _is_interpolation(self.__dict__["_content"])
  602. @abstractmethod
  603. def _validate_get(self, key: Any, value: Any = None) -> None:
  604. ...
  605. @abstractmethod
  606. def _validate_set(self, key: Any, value: Any) -> None:
  607. ...
  608. def _value(self) -> Any:
  609. return self.__dict__["_content"]
  610. def _get_full_key(self, key: Union[DictKeyType, int, slice, None]) -> str:
  611. from .listconfig import ListConfig
  612. from .omegaconf import _select_one
  613. if not isinstance(key, (int, str, Enum, float, bool, slice, bytes, type(None))):
  614. return ""
  615. def _slice_to_str(x: slice) -> str:
  616. if x.step is not None:
  617. return f"{x.start}:{x.stop}:{x.step}"
  618. else:
  619. return f"{x.start}:{x.stop}"
  620. def prepand(
  621. full_key: str,
  622. parent_type: Any,
  623. cur_type: Any,
  624. key: Optional[Union[DictKeyType, int, slice]],
  625. ) -> str:
  626. if key is None:
  627. return full_key
  628. if isinstance(key, slice):
  629. key = _slice_to_str(key)
  630. elif isinstance(key, Enum):
  631. key = key.name
  632. else:
  633. key = str(key)
  634. assert isinstance(key, str)
  635. if issubclass(parent_type, ListConfig):
  636. if full_key != "":
  637. if issubclass(cur_type, ListConfig):
  638. full_key = f"[{key}]{full_key}"
  639. else:
  640. full_key = f"[{key}].{full_key}"
  641. else:
  642. full_key = f"[{key}]"
  643. else:
  644. if full_key == "":
  645. full_key = key
  646. else:
  647. if issubclass(cur_type, ListConfig):
  648. full_key = f"{key}{full_key}"
  649. else:
  650. full_key = f"{key}.{full_key}"
  651. return full_key
  652. if key is not None and key != "":
  653. assert isinstance(self, Container)
  654. cur, _ = _select_one(
  655. c=self, key=str(key), throw_on_missing=False, throw_on_type_error=False
  656. )
  657. if cur is None:
  658. cur = self
  659. full_key = prepand("", type(cur), None, key)
  660. if cur._key() is not None:
  661. full_key = prepand(
  662. full_key, type(cur._get_parent()), type(cur), cur._key()
  663. )
  664. else:
  665. full_key = prepand("", type(cur._get_parent()), type(cur), cur._key())
  666. else:
  667. cur = self
  668. if cur._key() is None:
  669. return ""
  670. full_key = self._key()
  671. assert cur is not None
  672. memo = {id(cur)} # remember already visited nodes so as to detect cycles
  673. while cur._get_parent() is not None:
  674. cur = cur._get_parent()
  675. if id(cur) in memo:
  676. raise ConfigCycleDetectedException(
  677. f"Cycle when iterating over parents of key `{key!s}`"
  678. )
  679. memo.add(id(cur))
  680. assert cur is not None
  681. if cur._key() is not None:
  682. full_key = prepand(
  683. full_key, type(cur._get_parent()), type(cur), cur._key()
  684. )
  685. return full_key
  686. def _create_structured_with_missing_fields(
  687. ref_type: type, object_type: Optional[type] = None
  688. ) -> "DictConfig":
  689. from . import MISSING, DictConfig
  690. cfg_data = get_structured_config_data(ref_type)
  691. for v in cfg_data.values():
  692. v._set_value(MISSING)
  693. cfg = DictConfig(cfg_data)
  694. cfg._metadata.optional, cfg._metadata.ref_type = _resolve_optional(ref_type)
  695. cfg._metadata.object_type = object_type
  696. return cfg
  697. def _update_types(node: Node, ref_type: Any, object_type: Optional[type]) -> None:
  698. if object_type is not None and not is_primitive_dict(object_type):
  699. node._metadata.object_type = object_type
  700. if node._metadata.ref_type is Any:
  701. _deep_update_type_hint(node, ref_type)
  702. def _deep_update_type_hint(node: Node, type_hint: Any) -> None:
  703. """Ensure node is compatible with type_hint, mutating if necessary."""
  704. from omegaconf import DictConfig, ListConfig
  705. from ._utils import get_dict_key_value_types, get_list_element_type
  706. if type_hint is Any:
  707. return
  708. _shallow_validate_type_hint(node, type_hint)
  709. new_is_optional, new_ref_type = _resolve_optional(type_hint)
  710. node._metadata.ref_type = new_ref_type
  711. node._metadata.optional = new_is_optional
  712. if is_list_annotation(new_ref_type) and isinstance(node, ListConfig):
  713. new_element_type = get_list_element_type(new_ref_type)
  714. node._metadata.element_type = new_element_type
  715. if not _is_special(node):
  716. for i in range(len(node)):
  717. _deep_update_subnode(node, i, new_element_type)
  718. if is_dict_annotation(new_ref_type) and isinstance(node, DictConfig):
  719. new_key_type, new_element_type = get_dict_key_value_types(new_ref_type)
  720. node._metadata.key_type = new_key_type
  721. node._metadata.element_type = new_element_type
  722. if not _is_special(node):
  723. for key in node:
  724. if new_key_type is not Any and not isinstance(key, new_key_type):
  725. raise KeyValidationError(
  726. f"Key {key!r} ({type(key).__name__}) is incompatible"
  727. + f" with key type hint '{new_key_type.__name__}'"
  728. )
  729. _deep_update_subnode(node, key, new_element_type)
  730. def _deep_update_subnode(node: BaseContainer, key: Any, value_type_hint: Any) -> None:
  731. """Get node[key] and ensure it is compatible with value_type_hint, mutating if necessary."""
  732. subnode = node._get_node(key)
  733. assert isinstance(subnode, Node)
  734. if _is_special(subnode):
  735. # Ensure special values are wrapped in a Node subclass that
  736. # is compatible with the type hint.
  737. node._wrap_value_and_set(key, subnode._value(), value_type_hint)
  738. subnode = node._get_node(key)
  739. assert isinstance(subnode, Node)
  740. _deep_update_type_hint(subnode, value_type_hint)
  741. def _shallow_validate_type_hint(node: Node, type_hint: Any) -> None:
  742. """Error if node's type, content and metadata are not compatible with type_hint."""
  743. from omegaconf import DictConfig, ListConfig, ValueNode
  744. is_optional, ref_type = _resolve_optional(type_hint)
  745. vk = get_value_kind(node)
  746. if node._is_none():
  747. if not is_optional:
  748. value = _get_value(node)
  749. raise ValidationError(
  750. f"Value {value!r} ({type(value).__name__})"
  751. + f" is incompatible with type hint '{ref_type.__name__}'"
  752. )
  753. return
  754. elif vk in (ValueKind.MANDATORY_MISSING, ValueKind.INTERPOLATION):
  755. return
  756. elif vk == ValueKind.VALUE:
  757. if is_primitive_type_annotation(ref_type) and isinstance(node, ValueNode):
  758. value = node._value()
  759. if not isinstance(value, ref_type):
  760. raise ValidationError(
  761. f"Value {value!r} ({type(value).__name__})"
  762. + f" is incompatible with type hint '{ref_type.__name__}'"
  763. )
  764. elif is_structured_config(ref_type) and isinstance(node, DictConfig):
  765. return
  766. elif is_dict_annotation(ref_type) and isinstance(node, DictConfig):
  767. return
  768. elif is_list_annotation(ref_type) and isinstance(node, ListConfig):
  769. return
  770. else:
  771. if isinstance(node, ValueNode):
  772. value = node._value()
  773. raise ValidationError(
  774. f"Value {value!r} ({type(value).__name__})"
  775. + f" is incompatible with type hint '{ref_type}'"
  776. )
  777. else:
  778. raise ValidationError(
  779. f"'{type(node).__name__}' is incompatible"
  780. + f" with type hint '{ref_type}'"
  781. )
  782. else:
  783. assert False