| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776 |
- import copy
- from enum import Enum
- from typing import (
- Any,
- Dict,
- ItemsView,
- Iterable,
- Iterator,
- KeysView,
- List,
- MutableMapping,
- Optional,
- Sequence,
- Tuple,
- Type,
- Union,
- )
- from ._utils import (
- _DEFAULT_MARKER_,
- ValueKind,
- _get_value,
- _is_interpolation,
- _is_missing_literal,
- _is_missing_value,
- _is_none,
- _resolve_optional,
- _valid_dict_key_annotation_type,
- format_and_raise,
- get_structured_config_data,
- get_structured_config_init_field_names,
- get_type_of,
- get_value_kind,
- is_container_annotation,
- is_dict,
- is_primitive_dict,
- is_structured_config,
- is_structured_config_frozen,
- type_str,
- )
- from .base import Box, Container, ContainerMetadata, DictKeyType, Node
- from .basecontainer import BaseContainer
- from .errors import (
- ConfigAttributeError,
- ConfigKeyError,
- ConfigTypeError,
- InterpolationResolutionError,
- KeyValidationError,
- MissingMandatoryValue,
- OmegaConfBaseException,
- ReadonlyConfigError,
- ValidationError,
- )
- from .nodes import EnumNode, ValueNode
class DictConfig(BaseContainer, MutableMapping[Any, Any]):
    """
    Mutable, dict-like OmegaConf container.

    ``_content`` is a dict mapping normalized keys to child ``Node`` objects, or
    ``None`` / a string when the config itself is None, a missing value or an
    interpolation.
    """

    _metadata: ContainerMetadata
    _content: Union[Dict[DictKeyType, Node], None, str]

    def __init__(
        self,
        content: Union[Dict[DictKeyType, Any], "DictConfig", Any],
        key: Any = None,
        parent: Optional[Box] = None,
        ref_type: Union[Any, Type[Any]] = Any,
        key_type: Union[Any, Type[Any]] = Any,
        element_type: Union[Any, Type[Any]] = Any,
        is_optional: bool = True,
        flags: Optional[Dict[str, bool]] = None,
    ) -> None:
        """
        Build a DictConfig node.

        :param content: initial value: a dict, another DictConfig, a structured
            config, None, a missing value or an interpolation string.
        :param key: key of this node within its parent.
        :param parent: parent container, if any.
        :param ref_type: reference type; assignments to this node are
            validated against it (see ``_validate_set``).
        :param key_type: type that keys are validated and normalized against.
        :param element_type: type of the values; its optionality decides
            whether new keys may be assigned None.
        :param is_optional: whether this node itself may be None.
        :param flags: node flags; when None and ``content`` is a DictConfig,
            the flags of ``content`` are used.
        :raises: any error raised while building the node is re-raised through
            ``format_and_raise``.
        """
        try:
            if isinstance(content, DictConfig):
                # Inherit the source config's flags unless flags were given.
                if flags is None:
                    flags = content._metadata.flags
            super().__init__(
                parent=parent,
                metadata=ContainerMetadata(
                    key=key,
                    optional=is_optional,
                    ref_type=ref_type,
                    object_type=dict,
                    key_type=key_type,
                    element_type=element_type,
                    flags=flags,
                ),
            )

            if not _valid_dict_key_annotation_type(key_type):
                raise KeyValidationError(f"Unsupported key type {key_type}")

            if is_structured_config(content) or is_structured_config(ref_type):
                self._set_value(content, flags=flags)
                # Frozen structured configs yield a read-only DictConfig.
                if is_structured_config_frozen(content) or is_structured_config_frozen(
                    ref_type
                ):
                    self._set_flag("readonly", True)

            else:
                if isinstance(content, DictConfig):
                    # Start from a copy of the source metadata and override the
                    # fields that were passed explicitly to this constructor.
                    metadata = copy.deepcopy(content._metadata)
                    metadata.key = key
                    metadata.ref_type = ref_type
                    metadata.optional = is_optional
                    metadata.element_type = element_type
                    metadata.key_type = key_type
                    self.__dict__["_metadata"] = metadata
                self._set_value(content, flags=flags)
        except Exception as ex:
            format_and_raise(node=None, key=key, value=None, cause=ex, msg=str(ex))

    def __deepcopy__(self, memo: Dict[int, Any]) -> "DictConfig":
        """
        Deep-copy this node's metadata, flag cache and children.

        The parent reference is shared with the original rather than copied.
        """
        res = DictConfig(None)
        res.__dict__["_metadata"] = copy.deepcopy(self.__dict__["_metadata"], memo=memo)
        res.__dict__["_flags_cache"] = copy.deepcopy(
            self.__dict__["_flags_cache"], memo=memo
        )

        src_content = self.__dict__["_content"]
        if isinstance(src_content, dict):
            content_copy = {}
            for k, v in src_content.items():
                # Detach the child while it is copied so the copy does not carry
                # a reference to this node; the copy is re-parented to `res`,
                # and the original child's parent is always restored.
                old_parent = v.__dict__["_parent"]
                try:
                    v.__dict__["_parent"] = None
                    vc = copy.deepcopy(v, memo=memo)
                    vc.__dict__["_parent"] = res
                    content_copy[k] = vc
                finally:
                    v.__dict__["_parent"] = old_parent
        else:
            # None and strings can be assigned as is
            content_copy = src_content

        res.__dict__["_content"] = content_copy
        # parent is retained, but not copied
        res.__dict__["_parent"] = self.__dict__["_parent"]
        return res

    def copy(self) -> "DictConfig":
        """Return ``copy.copy(self)``."""
        return copy.copy(self)

    def _is_typed(self) -> bool:
        """
        Return True when ``object_type`` is set to something other than
        Any/None or a dict type.

        ``_set_value_impl`` sets ``object_type`` to the structured config's
        type for structured content.
        """
        return self._metadata.object_type not in (Any, None) and not is_dict(
            self._metadata.object_type
        )

    def _validate_get(self, key: Any, value: Any = None) -> None:
        """
        Raise ConfigAttributeError when ``key`` is absent and this node is typed
        or in struct mode.

        A typed node whose own ``struct`` flag is explicitly False allows the
        access. ``value`` is only forwarded to the error report.
        """
        is_typed = self._is_typed()

        is_struct = self._get_flag("struct") is True
        if key not in self.__dict__["_content"]:
            if is_typed:
                # do not raise an exception if struct is explicitly set to False
                if self._get_node_flag("struct") is False:
                    return
            if is_typed or is_struct:
                if is_typed:
                    assert self._metadata.object_type not in (dict, None)
                    msg = f"Key '{key}' not in '{self._metadata.object_type.__name__}'"
                else:
                    msg = f"Key '{key}' is not in struct"
                self._format_and_raise(
                    key=key, value=value, cause=ConfigAttributeError(msg)
                )

    def _validate_set(self, key: Any, value: Any) -> None:
        """
        Validate assigning ``value`` to ``key`` (or to this node if key is None).

        Interpolations and missing values are always accepted. None is checked
        against optionality. Otherwise, when the target is a DictConfig whose
        ref_type is not Any/dict, the value's type must be compatible with it.

        :raises ValidationError: on a type mismatch.
        """
        from omegaconf import OmegaConf

        vk = get_value_kind(value)
        if vk == ValueKind.INTERPOLATION:
            return
        if _is_none(value):
            self._validate_non_optional(key, value)
            return
        if vk == ValueKind.MANDATORY_MISSING or value is None:
            return

        target = self._get_node(key) if key is not None else self

        # Only DictConfig targets with a concrete ref_type constrain the value.
        target_has_ref_type = isinstance(
            target, DictConfig
        ) and target._metadata.ref_type not in (Any, dict)
        is_valid_target = target is None or not target_has_ref_type

        if is_valid_target:
            return

        assert isinstance(target, Node)

        target_type = target._metadata.ref_type
        value_type = OmegaConf.get_type(value)

        if is_dict(value_type) and is_dict(target_type):
            return
        if is_container_annotation(target_type) and not is_container_annotation(
            value_type
        ):
            raise ValidationError(
                f"Cannot assign {type_str(value_type)} to {type_str(target_type)}"
            )
        if target_type is not None and value_type is not None:
            # Compare against the unsubscripted origin for generic aliases.
            origin = getattr(target_type, "__origin__", target_type)
            if not issubclass(value_type, origin):
                self._raise_invalid_value(value, value_type, target_type)

    def _validate_merge(self, value: Any) -> None:
        """
        Validate merging the config ``value`` into this node.

        None must be permitted by this node's optionality. If this node is
        missing and ``value`` has a non-dict object type, the assignment is
        validated as a set. If this node is a structured config, a non-None,
        non-dict source must be an instance of a subclass of its type.

        :raises ValidationError: when the source type is not compatible.
        """
        from omegaconf import OmegaConf

        dest = self
        src = value

        self._validate_non_optional(None, src)

        dest_obj_type = OmegaConf.get_type(dest)
        src_obj_type = OmegaConf.get_type(src)

        if dest._is_missing() and src._metadata.object_type not in (dict, None):
            self._validate_set(key=None, value=_get_value(src))

        if src._is_missing():
            return

        validation_error = (
            dest_obj_type is not None
            and src_obj_type is not None
            and is_structured_config(dest_obj_type)
            and not src._is_none()
            and not is_dict(src_obj_type)
            and not issubclass(src_obj_type, dest_obj_type)
        )
        if validation_error:
            msg = (
                f"Merge error: {type_str(src_obj_type)} is not a "
                f"subclass of {type_str(dest_obj_type)}. value: {src}"
            )
            raise ValidationError(msg)

    def _validate_non_optional(self, key: Optional[DictKeyType], value: Any) -> None:
        """
        Reject a None ``value`` (after resolving interpolations) for a
        non-optional target.

        The target is the existing child at ``key``, a new child governed by
        ``element_type`` when ``key`` is absent, or this node when key is None.
        """
        if _is_none(value, resolve=True, throw_on_resolution_failure=False):
            if key is not None:
                child = self._get_node(key)
                if child is not None:
                    assert isinstance(child, Node)
                    field_is_optional = child._is_optional()
                else:
                    field_is_optional, _ = _resolve_optional(
                        self._metadata.element_type
                    )
            else:
                field_is_optional = self._is_optional()

            if not field_is_optional:
                self._format_and_raise(
                    key=key,
                    value=value,
                    cause=ValidationError("field '$FULL_KEY' is not Optional"),
                )

    def _raise_invalid_value(
        self, value: Any, value_type: Any, target_type: Any
    ) -> None:
        """Raise a ValidationError describing a type mismatch on assignment."""
        assert value_type is not None
        assert target_type is not None
        msg = (
            f"Invalid type assigned: {type_str(value_type)} is not a "
            f"subclass of {type_str(target_type)}. value: {value}"
        )
        raise ValidationError(msg)

    def _validate_and_normalize_key(self, key: Any) -> DictKeyType:
        """Validate/normalize ``key`` against this node's ``key_type``."""
        return self._s_validate_and_normalize_key(self._metadata.key_type, key)

    def _s_validate_and_normalize_key(self, key_type: Any, key: Any) -> DictKeyType:
        """
        Validate ``key`` against ``key_type`` and return the normalized key.

        - Any: the key must be an instance of one of DictKeyType's members.
        - bool: 0/1 are accepted and converted to False/True.
        - str/bytes/int/float/bool: the key must be an instance of key_type.
        - Enum subclasses: converted with EnumNode.validate_and_convert_to_enum.

        :raises KeyValidationError: when the key is incompatible.
        """
        if key_type is Any:
            for t in DictKeyType.__args__:  # type: ignore
                if isinstance(key, t):
                    return key  # type: ignore
            raise KeyValidationError("Incompatible key type '$KEY_TYPE'")
        elif key_type is bool and key in [0, 1]:
            # Python treats True as 1 and False as 0 when used as dict keys
            #   assert hash(0) == hash(False)
            #   assert hash(1) == hash(True)
            return bool(key)
        elif key_type in (str, bytes, int, float, bool):  # primitive type
            if not isinstance(key, key_type):
                raise KeyValidationError(
                    f"Key $KEY ($KEY_TYPE) is incompatible with ({key_type.__name__})"
                )

            return key  # type: ignore
        elif issubclass(key_type, Enum):
            try:
                return EnumNode.validate_and_convert_to_enum(key_type, key)
            except ValidationError:
                # Iterating the members mapping yields the member names.
                valid = ", ".join(key_type.__members__)
                raise KeyValidationError(
                    f"Key '$KEY' is incompatible with the enum type '{key_type.__name__}', valid: [{valid}]"
                )
        else:
            assert False, f"Unsupported key type {key_type}"

    def __setitem__(self, key: DictKeyType, value: Any) -> None:
        """
        Allow map style assignment.

        AttributeErrors are reported as ConfigKeyError.
        """
        try:
            self.__set_impl(key=key, value=value)
        except AttributeError as e:
            self._format_and_raise(
                key=key, value=value, type_override=ConfigKeyError, cause=e
            )
        except Exception as e:
            self._format_and_raise(key=key, value=value, cause=e)

    def __set_impl(self, key: DictKeyType, value: Any) -> None:
        """Normalize ``key`` and delegate the assignment to ``_set_item_impl``."""
        key = self._validate_and_normalize_key(key)
        self._set_item_impl(key, value)

    # hide content while inspecting in debugger
    def __dir__(self) -> Iterable[str]:
        """Return the config keys (empty when missing or None)."""
        if self._is_missing() or self._is_none():
            return []
        return self.__dict__["_content"].keys()  # type: ignore

    def __setattr__(self, key: str, value: Any) -> None:
        """
        Allow assigning attributes to DictConfig
        :param key:
        :param value:
        :return:
        """
        try:
            self.__set_impl(key, value)
        except Exception as e:
            # Errors already formatted by a nested call are re-raised unchanged.
            if isinstance(e, OmegaConfBaseException) and e._initialized:
                raise e
            self._format_and_raise(key=key, value=value, cause=e)
            assert False

    def __getattr__(self, key: str) -> Any:
        """
        Allow accessing dictionary values as attributes
        :param key:
        :return:
        """
        # `__name__` is never looked up as a config key.
        if key == "__name__":
            raise AttributeError()

        try:
            return self._get_impl(
                key=key, default_value=_DEFAULT_MARKER_, validate_key=False
            )
        except ConfigKeyError as e:
            self._format_and_raise(
                key=key, value=None, cause=e, type_override=ConfigAttributeError
            )
        except Exception as e:
            self._format_and_raise(key=key, value=None, cause=e)

    def __getitem__(self, key: DictKeyType) -> Any:
        """
        Allow map style access
        :param key:
        :return:
        """

        try:
            return self._get_impl(key=key, default_value=_DEFAULT_MARKER_)
        except AttributeError as e:
            self._format_and_raise(
                key=key, value=None, cause=e, type_override=ConfigKeyError
            )
        except Exception as e:
            self._format_and_raise(key=key, value=None, cause=e)

    def __delattr__(self, key: str) -> None:
        """
        Allow deleting dictionary values as attributes
        :param key:
        :return:
        :raises ReadonlyConfigError: when the config is read-only.
        :raises ConfigAttributeError: when there is no such key.
        """
        if self._get_flag("readonly"):
            self._format_and_raise(
                key=key,
                value=None,
                cause=ReadonlyConfigError(
                    "DictConfig in read-only mode does not support deletion"
                ),
            )
        # Normalize the key the same way attribute reads do (e.g. an Enum
        # member name maps to the Enum key), so `del cfg.X` removes exactly the
        # entry that `cfg.X` returns. Keys that fail validation are used as-is
        # and end up in the "Attribute not found" branch below.
        try:
            normalized_key: Any = self._validate_and_normalize_key(key)
        except KeyValidationError:
            normalized_key = key
        try:
            del self.__dict__["_content"][normalized_key]
        except KeyError:
            msg = "Attribute not found: '$KEY'"
            self._format_and_raise(key=key, value=None, cause=ConfigAttributeError(msg))

    def __delitem__(self, key: DictKeyType) -> None:
        """
        Delete ``key``.

        :raises KeyValidationError: when the key is incompatible with key_type.
        :raises ReadonlyConfigError: when the config is read-only.
        :raises ConfigTypeError: in struct mode, or for typed (structured)
            configs unless their own struct flag is explicitly False.
        :raises ConfigKeyError: when there is no such key.
        """
        # Route key validation errors through _format_and_raise (as
        # __setitem__, get and pop do) so message placeholders are filled in.
        try:
            key = self._validate_and_normalize_key(key)
        except KeyValidationError as e:
            self._format_and_raise(key=key, value=None, cause=e)
        if self._get_flag("readonly"):
            self._format_and_raise(
                key=key,
                value=None,
                cause=ReadonlyConfigError(
                    "DictConfig in read-only mode does not support deletion"
                ),
            )
        if self._get_flag("struct"):
            self._format_and_raise(
                key=key,
                value=None,
                cause=ConfigTypeError(
                    "DictConfig in struct mode does not support deletion"
                ),
            )
        if self._is_typed() and self._get_node_flag("struct") is not False:
            self._format_and_raise(
                key=key,
                value=None,
                cause=ConfigTypeError(
                    f"{type_str(self._metadata.object_type)} (DictConfig) does not support deletion"
                ),
            )

        try:
            del self.__dict__["_content"][key]
        except KeyError:
            msg = "Key not found: '$KEY'"
            self._format_and_raise(key=key, value=None, cause=ConfigKeyError(msg))

    def get(self, key: DictKeyType, default_value: Any = None) -> Any:
        """Return the value for `key` if `key` is in the dictionary, else
        `default_value` (defaulting to `None`)."""
        try:
            return self._get_impl(key=key, default_value=default_value)
        except KeyValidationError as e:
            self._format_and_raise(key=key, value=None, cause=e)

    def _get_impl(
        self, key: DictKeyType, default_value: Any, validate_key: bool = True
    ) -> Any:
        """
        Look up ``key`` and resolve its value.

        When the key is absent, ``default_value`` is returned unless it is
        ``_DEFAULT_MARKER_``, in which case the lookup error propagates.
        """
        try:
            node = self._get_child(
                key=key, throw_on_missing_key=True, validate_key=validate_key
            )
        except (ConfigAttributeError, ConfigKeyError):
            if default_value is not _DEFAULT_MARKER_:
                return default_value
            else:
                raise
        assert isinstance(node, Node)
        return self._resolve_with_default(
            key=key, value=node, default_value=default_value
        )

    def _get_node(
        self,
        key: DictKeyType,
        validate_access: bool = True,
        validate_key: bool = True,
        throw_on_missing_value: bool = False,
        throw_on_missing_key: bool = False,
    ) -> Optional[Node]:
        """
        Return the child node stored under ``key``, or None if there is none.

        :param validate_access: apply ``_validate_get`` (typed/struct checks)
            and propagate key validation errors.
        :param validate_key: together with ``validate_access``, propagate key
            validation errors instead of treating the key as absent.
        :param throw_on_missing_value: raise MissingMandatoryValue when the
            child holds a missing value.
        :param throw_on_missing_key: raise instead of returning None when the
            key is absent or invalid.
        """
        try:
            key = self._validate_and_normalize_key(key)
        except KeyValidationError:
            if validate_access and validate_key:
                raise
            else:
                if throw_on_missing_key:
                    raise ConfigAttributeError
                else:
                    return None

        if validate_access:
            self._validate_get(key)

        value: Optional[Node] = self.__dict__["_content"].get(key)
        if value is None:
            if throw_on_missing_key:
                raise ConfigKeyError(f"Missing key {key!s}")
        elif throw_on_missing_value and value._is_missing():
            raise MissingMandatoryValue("Missing mandatory value: $KEY")
        return value

    def pop(self, key: DictKeyType, default: Any = _DEFAULT_MARKER_) -> Any:
        """
        Remove ``key`` and return its resolved value.

        If the key is absent, ``default`` is returned when given; otherwise
        ConfigKeyError is raised. Read-only, struct-mode and typed configs
        (unless their own struct flag is explicitly False) reject pop.
        """
        try:
            if self._get_flag("readonly"):
                raise ReadonlyConfigError("Cannot pop from read-only node")
            if self._get_flag("struct"):
                raise ConfigTypeError("DictConfig in struct mode does not support pop")
            if self._is_typed() and self._get_node_flag("struct") is not False:
                raise ConfigTypeError(
                    f"{type_str(self._metadata.object_type)} (DictConfig) does not support pop"
                )
            key = self._validate_and_normalize_key(key)
            node = self._get_child(key=key, validate_access=False)
            if node is not None:
                assert isinstance(node, Node)
                value = self._resolve_with_default(
                    key=key, value=node, default_value=default
                )

                del self[key]
                return value
            else:
                if default is not _DEFAULT_MARKER_:
                    return default
                else:
                    full = self._get_full_key(key=key)
                    if full != key:
                        raise ConfigKeyError(
                            f"Key not found: '{key!s}' (path: '{full}')"
                        )
                    else:
                        raise ConfigKeyError(f"Key not found: '{key!s}'")
        except Exception as e:
            self._format_and_raise(key=key, value=None, cause=e)

    def keys(self) -> KeysView[DictKeyType]:
        """Return the keys; empty when missing, an interpolation, or None."""
        if self._is_missing() or self._is_interpolation() or self._is_none():
            return {}.keys()
        ret = self.__dict__["_content"].keys()
        assert isinstance(ret, KeysView)
        return ret

    def __contains__(self, key: object) -> bool:
        """
        A key is contained in a DictConfig if there is an associated value and
        it is not a mandatory missing value ('???').
        :param key:
        :return:
        """

        try:
            key = self._validate_and_normalize_key(key)
        except KeyValidationError:
            return False

        try:
            node = self._get_child(key)
            assert node is None or isinstance(node, Node)
        except (KeyError, AttributeError):
            node = None

        if node is None:
            return False
        else:
            try:
                self._resolve_with_default(key=key, value=node)
                return True
            except InterpolationResolutionError:
                # Interpolations that fail count as existing.
                return True
            except MissingMandatoryValue:
                # Missing values count as *not* existing.
                return False

    def __iter__(self) -> Iterator[DictKeyType]:
        """Iterate over the keys."""
        return iter(self.keys())

    def items(self) -> ItemsView[DictKeyType, Any]:
        """Return a view of (key, resolved value) pairs."""
        return dict(self.items_ex(resolve=True, keys=None)).items()

    def setdefault(self, key: DictKeyType, default: Any = None) -> Any:
        """
        Return the value of ``key`` if it is contained (see ``__contains__``);
        otherwise assign ``default`` to it and return ``default``.
        """
        if key in self:
            ret = self.__getitem__(key)
        else:
            ret = default
            self.__setitem__(key, default)
        return ret

    def items_ex(
        self, resolve: bool = True, keys: Optional[Sequence[DictKeyType]] = None
    ) -> List[Tuple[DictKeyType, Any]]:
        """
        Return a list of (key, value) pairs.

        :param resolve: when True values are read through ``__getitem__``;
            otherwise the stored nodes are returned, with ValueNodes unwrapped
            to their raw ``_value()``.
        :param keys: when given, only pairs whose key is in ``keys`` are kept.
        :raises TypeError: (formatted) when this config represents None.
        :raises MissingMandatoryValue: when this config is missing.
        """
        items: List[Tuple[DictKeyType, Any]] = []

        if self._is_none():
            self._format_and_raise(
                key=None,
                value=None,
                cause=TypeError("Cannot iterate a DictConfig object representing None"),
            )
        if self._is_missing():
            raise MissingMandatoryValue("Cannot iterate a missing DictConfig")

        for key in self.keys():
            if resolve:
                value = self[key]
            else:
                value = self.__dict__["_content"][key]
                if isinstance(value, ValueNode):
                    value = value._value()
            if keys is None or key in keys:
                items.append((key, value))

        return items

    def __eq__(self, other: Any) -> bool:
        """
        Compare with None, a plain dict / structured config (wrapped into a
        DictConfig with ``allow_objects``), or another DictConfig.

        A missing DictConfig compares equal to the missing literal.
        """
        if other is None:
            return self.__dict__["_content"] is None
        if is_primitive_dict(other) or is_structured_config(other):
            other = DictConfig(other, flags={"allow_objects": True})
            return DictConfig._dict_conf_eq(self, other)
        if isinstance(other, DictConfig):
            return DictConfig._dict_conf_eq(self, other)
        if self._is_missing():
            return _is_missing_literal(other)
        return NotImplemented

    def __ne__(self, other: Any) -> bool:
        """Negation of ``__eq__``, propagating NotImplemented."""
        x = self.__eq__(other)
        if x is not NotImplemented:
            return not x
        return NotImplemented

    def __hash__(self) -> int:
        """Hash of ``str(self)``."""
        return hash(str(self))

    def _promote(self, type_or_prototype: Optional[Type[Any]]) -> None:
        """
        Retypes a node.
        This should only be used in rare circumstances, where you want to dynamically change
        the runtime structured-type of a DictConfig.
        It will change the type and add the additional fields based on the input class or object
        """
        if type_or_prototype is None:
            return
        if not is_structured_config(type_or_prototype):
            raise ValueError(f"Expected structured config class: {type_or_prototype}")

        from omegaconf import OmegaConf

        proto: DictConfig = OmegaConf.structured(type_or_prototype)
        object_type = proto._metadata.object_type
        # remove the type to prevent assignment validation from rejecting the promotion.
        proto._metadata.object_type = None
        self.merge_with(proto)
        # restore the type.
        self._metadata.object_type = object_type

    def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
        """
        Replace the content of this node with ``value``.

        If ``_set_value_impl`` fails, the previous ``_content`` and
        ``object_type`` are restored before the exception is re-raised.
        ``object_type`` must be rolled back too: ``_set_value_impl`` may have
        reset it (e.g. to None for structured configs) before failing.
        NOTE(review): flags replaced by the DictConfig branch of
        ``_set_value_impl`` are not rolled back.
        """
        previous_content = self.__dict__["_content"]
        previous_object_type = self._metadata.object_type
        try:
            self._set_value_impl(value, flags)
        except Exception:
            self.__dict__["_content"] = previous_content
            self._metadata.object_type = previous_object_type
            raise

    def _set_value_impl(
        self, value: Any, flags: Optional[Dict[str, bool]] = None
    ) -> None:
        """
        Assign ``value`` as the new content of this node.

        None, interpolations and missing values are stored directly (with
        ``object_type`` cleared). Structured configs, DictConfigs and dicts are
        copied key by key with struct/readonly temporarily disabled, and
        ``object_type`` is set from the source.

        :raises ValidationError: for unsupported value types.
        """
        from omegaconf import MISSING, flag_override

        if flags is None:
            flags = {}

        assert not isinstance(value, ValueNode)
        self._validate_set(key=None, value=value)

        if _is_none(value, resolve=True):
            self.__dict__["_content"] = None
            self._metadata.object_type = None
        elif _is_interpolation(value, strict_interpolation_validation=True):
            self.__dict__["_content"] = value
            self._metadata.object_type = None
        elif _is_missing_value(value):
            self.__dict__["_content"] = MISSING
            self._metadata.object_type = None
        else:
            self.__dict__["_content"] = {}
            if is_structured_config(value):
                # Cleared while populating; set to the structured type at the end.
                self._metadata.object_type = None
                ao = self._get_flag("allow_objects")
                data = get_structured_config_data(value, allow_objects=ao)
                with flag_override(self, ["struct", "readonly"], False):
                    for k, v in data.items():
                        self.__setitem__(k, v)
                self._metadata.object_type = get_type_of(value)

            elif isinstance(value, DictConfig):
                self._metadata.flags = copy.deepcopy(flags)
                with flag_override(self, ["struct", "readonly"], False):
                    for k, v in value.__dict__["_content"].items():
                        self.__setitem__(k, v)
                self._metadata.object_type = value._metadata.object_type

            elif isinstance(value, dict):
                with flag_override(self, ["struct", "readonly"], False):
                    for k, v in value.items():
                        self.__setitem__(k, v)
                self._metadata.object_type = dict

            else:  # pragma: no cover
                msg = f"Unsupported value type: {value}"
                raise ValidationError(msg)

    @staticmethod
    def _dict_conf_eq(d1: "DictConfig", d2: "DictConfig") -> bool:
        """
        Structural equality of two DictConfigs.

        Two None configs are equal; None vs non-None is not. Otherwise lengths,
        missing-ness, key sets and per-key items (via
        ``BaseContainer._item_eq``) must match.
        """
        d1_none = d1.__dict__["_content"] is None
        d2_none = d2.__dict__["_content"] is None
        if d1_none and d2_none:
            return True
        if d1_none != d2_none:
            return False

        assert isinstance(d1, DictConfig)
        assert isinstance(d2, DictConfig)
        if len(d1) != len(d2):
            return False
        if d1._is_missing() or d2._is_missing():
            return d1._is_missing() is d2._is_missing()

        for k, v in d1.items_ex(resolve=False):
            if k not in d2.__dict__["_content"]:
                return False
            if not BaseContainer._item_eq(d1, k, d2, k):
                return False

        return True

    def _to_object(self) -> Any:
        """
        Instantiate an instance of `self._metadata.object_type`.
        This requires `self` to be a structured config.
        Nested subconfigs are converted by calling `OmegaConf.to_object`.
        """
        from omegaconf import OmegaConf

        object_type = self._metadata.object_type
        assert is_structured_config(object_type)
        init_field_names = set(get_structured_config_init_field_names(object_type))

        init_field_items: Dict[str, Any] = {}
        non_init_field_items: Dict[str, Any] = {}
        for k in self.keys():
            assert isinstance(k, str)
            node = self._get_child(k)
            assert isinstance(node, Node)
            try:
                node = node._dereference_node()
            except InterpolationResolutionError as e:
                self._format_and_raise(key=k, value=None, cause=e)
            if node._is_missing():
                if k not in init_field_names:
                    continue  # MISSING is ignored for init=False fields
                self._format_and_raise(
                    key=k,
                    value=None,
                    cause=MissingMandatoryValue(
                        "Structured config of type `$OBJECT_TYPE` has missing mandatory value: $KEY"
                    ),
                )
            if isinstance(node, Container):
                v = OmegaConf.to_object(node)
            else:
                v = node._value()

            if k in init_field_names:
                init_field_items[k] = v
            else:
                non_init_field_items[k] = v

        try:
            result = object_type(**init_field_items)
        except TypeError as exc:
            self._format_and_raise(
                key=None,
                value=None,
                cause=exc,
                msg="Could not create instance of `$OBJECT_TYPE`: " + str(exc),
            )

        # Fields excluded from __init__ are assigned after construction.
        for k, v in non_init_field_items.items():
            setattr(result, k, v)
        return result
|