| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679 |
- import copy
- import itertools
- from typing import (
- Any,
- Callable,
- Dict,
- Iterable,
- Iterator,
- List,
- MutableSequence,
- Optional,
- Tuple,
- Type,
- Union,
- )
- from ._utils import (
- ValueKind,
- _is_missing_literal,
- _is_none,
- _resolve_optional,
- format_and_raise,
- get_value_kind,
- is_int,
- is_primitive_list,
- is_structured_config,
- type_str,
- )
- from .base import Box, ContainerMetadata, Node
- from .basecontainer import BaseContainer
- from .errors import (
- ConfigAttributeError,
- ConfigTypeError,
- ConfigValueError,
- KeyValidationError,
- MissingMandatoryValue,
- ReadonlyConfigError,
- ValidationError,
- )
class ListConfig(BaseContainer, MutableSequence[Any]):
    """List-like OmegaConf container holding child nodes.

    ``_content`` is a list of child ``Node`` objects for a regular list,
    ``None`` when the ListConfig represents None, or a string (the missing
    literal or an interpolation) otherwise.
    """

    _content: Union[List[Node], None, str]

    def __init__(
        self,
        content: Union[List[Any], Tuple[Any, ...], "ListConfig", str, None],
        key: Any = None,
        parent: Optional[Box] = None,
        element_type: Union[Type[Any], Any] = Any,
        is_optional: bool = True,
        ref_type: Union[Type[Any], Any] = Any,
        flags: Optional[Dict[str, bool]] = None,
    ) -> None:
        """Create a ListConfig from ``content``.

        :param content: a list/tuple of values, another ListConfig, a string
            (missing literal or interpolation) or None.
        :param key: key of this node inside its parent.
        :param parent: parent container, if any.
        :param element_type: stored as the metadata ``element_type``.
        :param is_optional: stored as the metadata ``optional`` flag.
        :param ref_type: stored as the metadata ``ref_type``.
        :param flags: initial flags; when ``content`` is a ListConfig and
            ``flags`` is None, the flags of ``content`` are reused.
        """
        try:
            if isinstance(content, ListConfig):
                if flags is None:
                    flags = content._metadata.flags
            super().__init__(
                parent=parent,
                metadata=ContainerMetadata(
                    ref_type=ref_type,
                    object_type=list,
                    key=key,
                    optional=is_optional,
                    element_type=element_type,
                    key_type=int,
                    flags=flags,
                ),
            )
            if isinstance(content, ListConfig):
                # Start from a copy of the source's metadata, overriding the
                # fields explicitly passed to this constructor.
                metadata = copy.deepcopy(content._metadata)
                metadata.key = key
                metadata.ref_type = ref_type
                metadata.optional = is_optional
                metadata.element_type = element_type
                self.__dict__["_metadata"] = metadata
            self._set_value(value=content, flags=flags)
        except Exception as ex:
            format_and_raise(node=None, key=key, value=None, cause=ex, msg=str(ex))

    def _validate_get(self, key: Any, value: Any = None) -> None:
        """Raise KeyValidationError unless ``key`` is an int or a slice.

        ``value`` is accepted for signature compatibility and not used here.
        """
        if not isinstance(key, (int, slice)):
            raise KeyValidationError(
                "ListConfig indices must be integers or slices, not $KEY_TYPE"
            )

    def _validate_set(self, key: Any, value: Any) -> None:
        """Check that ``value`` may be stored at ``key``.

        Raises KeyValidationError for a bad key, ReadonlyConfigError when the
        readonly flag is set, and ValidationError when None is assigned to a
        non-optional existing element or when the value's type does not fit
        the declared element type.
        """
        from omegaconf import OmegaConf

        self._validate_get(key, value)

        if self._get_flag("readonly"):
            raise ReadonlyConfigError("ListConfig is read-only")

        if 0 <= key < self.__len__():
            target = self._get_node(key)
            if target is not None:
                assert isinstance(target, Node)
                if value is None and not target._is_optional():
                    raise ValidationError(
                        "$FULL_KEY is not optional and cannot be assigned None"
                    )

        vk = get_value_kind(value)
        if vk == ValueKind.MANDATORY_MISSING:
            # The missing literal is always assignable.
            return
        else:
            is_optional, target_type = _resolve_optional(self._metadata.element_type)
            value_type = OmegaConf.get_type(value)

            if (value_type is None and not is_optional) or (
                is_structured_config(target_type)
                and value_type is not None
                and not issubclass(value_type, target_type)
            ):
                msg = (
                    f"Invalid type assigned: {type_str(value_type)} is not a "
                    f"subclass of {type_str(target_type)}. value: {value}"
                )
                raise ValidationError(msg)

    def __deepcopy__(self, memo: Dict[int, Any]) -> "ListConfig":
        """Deep-copy this ListConfig; the copy keeps the same parent reference."""
        res = ListConfig(None)
        res.__dict__["_metadata"] = copy.deepcopy(self.__dict__["_metadata"], memo=memo)
        res.__dict__["_flags_cache"] = copy.deepcopy(
            self.__dict__["_flags_cache"], memo=memo
        )

        src_content = self.__dict__["_content"]
        if isinstance(src_content, list):
            content_copy: List[Optional[Node]] = []
            for v in src_content:
                # Detach the child while copying it (presumably so the deep copy
                # does not follow the child's parent link back into this
                # container), then attach the copy to the new container.
                old_parent = v.__dict__["_parent"]
                try:
                    v.__dict__["_parent"] = None
                    vc = copy.deepcopy(v, memo=memo)
                    vc.__dict__["_parent"] = res
                    content_copy.append(vc)
                finally:
                    v.__dict__["_parent"] = old_parent
        else:
            # None and strings can be assigned as is
            content_copy = src_content

        res.__dict__["_content"] = content_copy
        res.__dict__["_parent"] = self.__dict__["_parent"]
        return res

    def copy(self) -> "ListConfig":
        """Return ``copy.copy(self)``."""
        return copy.copy(self)

    # hide content while inspecting in debugger
    def __dir__(self) -> Iterable[str]:
        """Return the indices as strings; empty when missing or None."""
        if self._is_missing() or self._is_none():
            return []
        return [str(x) for x in range(0, len(self))]

    def __setattr__(self, key: str, value: Any) -> None:
        """Attribute assignment is unsupported; reported as ConfigAttributeError."""
        self._format_and_raise(
            key=key,
            value=value,
            cause=ConfigAttributeError("ListConfig does not support attribute access"),
        )
        assert False

    def __getattr__(self, key: str) -> Any:
        """Allow integer-like attribute names as indices; reject everything else."""
        # PyCharm is sometimes inspecting __members__, be sure to tell it we don't have that.
        if key == "__members__":
            raise AttributeError()

        if key == "__name__":
            raise AttributeError()

        if is_int(key):
            return self.__getitem__(int(key))
        else:
            self._format_and_raise(
                key=key,
                value=None,
                cause=ConfigAttributeError(
                    "ListConfig does not support attribute access"
                ),
            )

    def __getitem__(self, index: Union[int, slice]) -> Any:
        """Return the resolved element at ``index``, or a list for a slice.

        Slices follow built-in list semantics, including negative and
        out-of-range bounds and negative steps. Errors are reported through
        ``_format_and_raise``.
        """
        try:
            if self._is_missing():
                raise MissingMandatoryValue("ListConfig is missing")
            self._validate_get(index, None)
            if self._is_none():
                raise TypeError(
                    "ListConfig object representing None is not subscriptable"
                )
            content = self.__dict__["_content"]
            assert isinstance(content, list)
            if isinstance(index, slice):
                # slice.indices() clamps out-of-range bounds and normalizes
                # negative bounds/steps exactly like list slicing, and the
                # resulting range yields indices already in output order.
                return [
                    self._resolve_with_default(
                        key=slice_idx, value=content[slice_idx]
                    )
                    for slice_idx in range(*index.indices(len(self)))
                ]
            return self._resolve_with_default(key=index, value=content[index])
        except Exception as e:
            self._format_and_raise(key=index, value=None, cause=e)

    def _correct_index_params(self, index: slice) -> Tuple[int, int, int]:
        """Legacy helper translating a slice into ``islice`` arguments.

        No longer used by ``__getitem__`` (which relies on ``slice.indices``).
        Kept for backward compatibility only; NOTE: it does not reproduce list
        semantics for every negative-step slice.
        """
        start = index.start
        stop = index.stop
        step = index.step
        if index.start and index.start < 0:
            start = self.__len__() + index.start
        if index.stop and index.stop < 0:
            stop = self.__len__() + index.stop
        if index.step and index.step < 0:
            step = abs(step)
            if start and stop:
                if start > stop:
                    start, stop = stop + 1, start + 1
                else:
                    start = stop = 0
            elif not start and stop:
                start = list(range(self.__len__() - 1, stop, -step))[0]
                stop = None
            elif start and not stop:
                stop = start + 1
                start = (stop - 1) % step
            else:
                start = (self.__len__() - 1) % step
        return start, stop, step

    def _set_at_index(self, index: Union[int, slice], value: Any) -> None:
        """Store ``value`` at ``index`` via ``_set_item_impl``."""
        self._set_item_impl(index, value)

    def __setitem__(self, index: Union[int, slice], value: Any) -> None:
        """Assign ``value`` at an int index, or an iterable to a slice.

        Extended slices (step other than 1) require an iterable of the same
        length, otherwise ValueError is raised. Slice edits are applied to a
        copy first and then copied back into this ListConfig.
        """
        try:
            if isinstance(index, slice):
                _ = iter(value)  # check iterable
                self_indices = index.indices(len(self))
                indexes = range(*self_indices)

                # Ensure lengths match for extended slice assignment
                if index.step not in (None, 1):
                    if len(indexes) != len(value):
                        raise ValueError(
                            f"attempt to assign sequence of size {len(value)}"
                            f" to extended slice of size {len(indexes)}"
                        )

                # Initialize insertion offsets for empty slices
                if len(indexes) == 0:
                    curr_index = self_indices[0] - 1
                    val_i = -1

                work_copy = self.copy()  # For atomicity manipulate a copy

                # Delete and optionally replace non empty slices
                only_removed = 0
                for val_i, i in enumerate(indexes):
                    curr_index = i - only_removed
                    del work_copy[curr_index]
                    if val_i < len(value):
                        work_copy.insert(curr_index, value[val_i])
                    else:
                        only_removed += 1

                # Insert any remaining input items
                for val_i in range(val_i + 1, len(value)):
                    curr_index += 1
                    work_copy.insert(curr_index, value[val_i])

                # Reinitialize self with work_copy
                self.clear()
                self.extend(work_copy)
            else:
                self._set_at_index(index, value)
        except Exception as e:
            self._format_and_raise(key=index, value=value, cause=e)

    def append(self, item: Any) -> None:
        """Append ``item`` to the end of the list.

        Raises (via ``_format_and_raise``) TypeError when this ListConfig
        represents None and MissingMandatoryValue when it is missing; the
        append is rolled back if storing the item fails.
        """
        content = self.__dict__["_content"]
        if not isinstance(content, list):
            # Without this guard, len(None) / "???".append would leak raw,
            # unformatted builtin errors.
            try:
                if self._is_none():
                    raise TypeError(
                        "Cannot append to a ListConfig object representing None"
                    )
                if self._is_missing():
                    raise MissingMandatoryValue("Cannot append to a missing ListConfig")
            except Exception as e:
                self._format_and_raise(key=None, value=item, cause=e)
                assert False

        index = len(content)
        # Reserve the slot first so _set_item_impl sees a valid index.
        content.append(None)
        try:
            self._set_item_impl(index, item)
        except Exception as e:
            del content[index]
            self._format_and_raise(key=index, value=item, cause=e)
            assert False

    def _update_keys(self) -> None:
        """Re-synchronize each child's metadata key with its list position."""
        for i in range(len(self)):
            node = self._get_node(i)
            if node is not None:
                assert isinstance(node, Node)
                node._metadata.key = i

    def insert(self, index: int, item: Any) -> None:
        """Insert ``item`` before position ``index``.

        Like ``list.insert``, a negative ``index`` counts from the end and an
        out-of-range ``index`` is clamped to the start/end of the list. On
        failure the list is left unchanged and the error is reported through
        ``_format_and_raise``.
        """
        from omegaconf.omegaconf import _maybe_wrap

        try:
            if self._get_flag("readonly"):
                raise ReadonlyConfigError("Cannot insert into a read-only ListConfig")
            if self._is_none():
                raise TypeError(
                    "Cannot insert into ListConfig object representing None"
                )
            if self._is_missing():
                raise MissingMandatoryValue("Cannot insert into missing ListConfig")

            content = self.__dict__["_content"]
            assert isinstance(content, list)
            # Normalize the position exactly like list.insert() so that the
            # placeholder and the node written at content[index] occupy the same
            # slot (a raw negative/past-the-end index would overwrite another
            # element or fail, leaving a stray None placeholder behind).
            length = len(content)
            if index < 0:
                index = max(length + index, 0)
            elif index > length:
                index = length

            # insert place holder
            content.insert(index, None)
            try:
                is_optional, ref_type = _resolve_optional(self._metadata.element_type)
                node = _maybe_wrap(
                    ref_type=ref_type,
                    key=index,
                    value=item,
                    is_optional=is_optional,
                    parent=self,
                )
                self._validate_set(key=index, value=node)
                self._set_at_index(index, node)
                self._update_keys()
            except Exception:
                # Roll back the placeholder so the list is left unchanged.
                del content[index]
                self._update_keys()
                raise
        except Exception as e:
            self._format_and_raise(key=index, value=item, cause=e)
            assert False

    def extend(self, lst: Iterable[Any]) -> None:
        """Append every element of ``lst`` (a tuple, list or ListConfig)."""
        assert isinstance(lst, (tuple, list, ListConfig))
        if lst is self:
            # Appending to self while iterating it would never terminate;
            # snapshot the current values first, as list.extend does.
            lst = list(lst)
        for x in lst:
            self.append(x)

    def remove(self, x: Any) -> None:
        """Remove the first element equal to ``x``."""
        del self[self.index(x)]

    def __delitem__(self, key: Union[int, slice]) -> None:
        """Delete the element(s) at ``key``; refused when read-only."""
        if self._get_flag("readonly"):
            self._format_and_raise(
                key=key,
                value=None,
                cause=ReadonlyConfigError(
                    "Cannot delete item from read-only ListConfig"
                ),
            )
        del self.__dict__["_content"][key]
        self._update_keys()

    def clear(self) -> None:
        """Remove all elements."""
        del self[:]

    def index(
        self, x: Any, start: Optional[int] = None, end: Optional[int] = None
    ) -> int:
        """Return the position of the first element equal to ``x``.

        ``start``/``end`` restrict the search like ``list.index``: negative
        values count from the end and out-of-range values are clamped.
        A missing element is reported as ConfigValueError through
        ``_format_and_raise``.
        """
        # slice.indices() applies the same bound normalization as list.index().
        start, end, _ = slice(start, end).indices(len(self))
        for idx in range(start, end):
            if x == self[idx]:
                return idx
        self._format_and_raise(
            key=None,
            value=None,
            cause=ConfigValueError("Item not found in ListConfig"),
        )
        assert False

    def count(self, x: Any) -> int:
        """Return how many (resolved) elements compare equal to ``x``."""
        return sum(1 for item in self if item == x)

    def _get_node(
        self,
        key: Union[int, slice],
        validate_access: bool = True,
        validate_key: bool = True,
        throw_on_missing_value: bool = False,
        throw_on_missing_key: bool = False,
    ) -> Union[Optional[Node], List[Optional[Node]]]:
        """Return the raw (unresolved) child node, or node list for a slice.

        :param validate_access: when False, lookup errors return None instead
            of being raised through ``_format_and_raise``.
        :param throw_on_missing_value: raise MissingMandatoryValue when a
            returned node is missing.
        ``validate_key`` and ``throw_on_missing_key`` are accepted for
        interface compatibility and are not used in this method.
        """
        try:
            if self._is_none():
                raise TypeError(
                    "Cannot get_node from a ListConfig object representing None"
                )
            if self._is_missing():
                raise MissingMandatoryValue("Cannot get_node from a missing ListConfig")
            assert isinstance(self.__dict__["_content"], list)
            if validate_access:
                self._validate_get(key)

            value = self.__dict__["_content"][key]
            if value is not None:
                if isinstance(key, slice):
                    assert isinstance(value, list)
                    for v in value:
                        if throw_on_missing_value and v._is_missing():
                            raise MissingMandatoryValue("Missing mandatory value")
                else:
                    assert isinstance(value, Node)
                    if throw_on_missing_value and value._is_missing():
                        raise MissingMandatoryValue("Missing mandatory value: $KEY")
            return value
        except (IndexError, TypeError, MissingMandatoryValue, KeyValidationError) as e:
            if isinstance(e, MissingMandatoryValue) and throw_on_missing_value:
                raise
            if validate_access:
                self._format_and_raise(key=key, value=None, cause=e)
                assert False
            else:
                return None

    def get(self, index: int, default_value: Any = None) -> Any:
        """Return the resolved element at ``index``.

        ``default_value`` is forwarded to ``_resolve_with_default``.
        """
        try:
            if self._is_none():
                raise TypeError("Cannot get from a ListConfig object representing None")
            if self._is_missing():
                raise MissingMandatoryValue("Cannot get from a missing ListConfig")
            self._validate_get(index, None)
            assert isinstance(self.__dict__["_content"], list)
            return self._resolve_with_default(
                key=index,
                value=self.__dict__["_content"][index],
                default_value=default_value,
            )
        except Exception as e:
            self._format_and_raise(key=index, value=None, cause=e)
            assert False

    def pop(self, index: int = -1) -> Any:
        """Remove the element at ``index`` and return its resolved value.

        A KeyValidationError is reported with ConfigTypeError as the type
        override; other errors are reported as-is via ``_format_and_raise``.
        """
        try:
            if self._get_flag("readonly"):
                raise ReadonlyConfigError("Cannot pop from read-only ListConfig")
            if self._is_none():
                raise TypeError("Cannot pop from a ListConfig object representing None")
            if self._is_missing():
                raise MissingMandatoryValue("Cannot pop from a missing ListConfig")

            assert isinstance(self.__dict__["_content"], list)
            node = self._get_child(index)
            assert isinstance(node, Node)
            ret = self._resolve_with_default(key=index, value=node, default_value=None)
            del self.__dict__["_content"][index]
            self._update_keys()
            return ret
        except KeyValidationError as e:
            self._format_and_raise(
                key=index, value=None, cause=e, type_override=ConfigTypeError
            )
            assert False
        except Exception as e:
            self._format_and_raise(key=index, value=None, cause=e)
            assert False

    def sort(
        self, key: Optional[Callable[[Any], Any]] = None, reverse: bool = False
    ) -> None:
        """Sort in place by each node's ``_value()``, optionally mapped by ``key``."""
        try:
            if self._get_flag("readonly"):
                raise ReadonlyConfigError("Cannot sort a read-only ListConfig")
            if self._is_none():
                raise TypeError("Cannot sort a ListConfig object representing None")
            if self._is_missing():
                raise MissingMandatoryValue("Cannot sort a missing ListConfig")

            # The content holds nodes, so compare on the node values.
            if key is None:

                def key1(x: Any) -> Any:
                    return x._value()

            else:

                def key1(x: Any) -> Any:
                    return key(x._value())  # type: ignore

            assert isinstance(self.__dict__["_content"], list)
            self.__dict__["_content"].sort(key=key1, reverse=reverse)

        except Exception as e:
            self._format_and_raise(key=None, value=None, cause=e)
            assert False

    def __eq__(self, other: Any) -> bool:
        """Compare element-wise with a list, tuple, None or another ListConfig."""
        if isinstance(other, (list, tuple)) or other is None:
            other = ListConfig(other, flags={"allow_objects": True})
            return ListConfig._list_eq(self, other)
        if isinstance(other, ListConfig):
            return ListConfig._list_eq(self, other)
        if self._is_missing():
            return _is_missing_literal(other)
        return NotImplemented

    def __ne__(self, other: Any) -> bool:
        """Negation of ``__eq__``, propagating NotImplemented."""
        x = self.__eq__(other)
        if x is not NotImplemented:
            return not x
        return NotImplemented

    def __hash__(self) -> int:
        """Hash of ``str(self)``; it changes if the list is mutated."""
        return hash(str(self))

    def __iter__(self) -> Iterator[Any]:
        """Iterate over the resolved element values."""
        return self._iter_ex(resolve=True)

    class ListIterator(Iterator[Any]):
        """Iterator over a ListConfig's child nodes.

        When ``resolve`` is set, nodes are dereferenced first and a missing
        element raises MissingMandatoryValue. Value nodes yield ``_value()``;
        containers yield themselves, or None when they represent None.
        """

        def __init__(self, lst: Any, resolve: bool) -> None:
            self.resolve = resolve
            self.iterator = iter(lst.__dict__["_content"])
            # position of the next element, used in error messages
            self.index = 0
            from .nodes import ValueNode

            self.ValueNode = ValueNode

        def __next__(self) -> Any:
            x = next(self.iterator)
            if self.resolve:
                x = x._dereference_node()
                if x._is_missing():
                    raise MissingMandatoryValue(f"Missing value at index {self.index}")

            self.index = self.index + 1
            if isinstance(x, self.ValueNode):
                return x._value()
            else:
                # Must be omegaconf.Container. not checking for perf reasons.
                if x._is_none():
                    return None
                return x

        def __repr__(self) -> str:  # pragma: no cover
            return f"ListConfig.ListIterator(resolve={self.resolve})"

    def _iter_ex(self, resolve: bool) -> Iterator[Any]:
        """Return a ListIterator; errors if this ListConfig is None or missing."""
        try:
            if self._is_none():
                raise TypeError("Cannot iterate a ListConfig object representing None")
            if self._is_missing():
                raise MissingMandatoryValue("Cannot iterate a missing ListConfig")

            return ListConfig.ListIterator(self, resolve)
        except (TypeError, MissingMandatoryValue) as e:
            self._format_and_raise(key=None, value=None, cause=e)
            assert False

    def __add__(self, other: Union[List[Any], "ListConfig"]) -> "ListConfig":
        """Return a new ListConfig with this list's elements followed by ``other``'s."""
        # res is sharing this list's parent to allow interpolation to work as expected
        res = ListConfig(parent=self._get_parent(), content=[])
        res.extend(self)
        res.extend(other)
        return res

    def __radd__(self, other: Union[List[Any], "ListConfig"]) -> "ListConfig":
        """Return a new ListConfig with ``other``'s elements followed by this list's."""
        # res is sharing this list's parent to allow interpolation to work as expected
        res = ListConfig(parent=self._get_parent(), content=[])
        res.extend(other)
        res.extend(self)
        return res

    def __iadd__(self, other: Iterable[Any]) -> "ListConfig":
        """Extend in place with ``other`` and return self."""
        self.extend(other)
        return self

    def __contains__(self, item: Any) -> bool:
        """Return True if any dereferenced child node compares equal to ``item``."""
        if self._is_none():
            raise TypeError(
                "Cannot check if an item is in a ListConfig object representing None"
            )
        if self._is_missing():
            raise MissingMandatoryValue(
                "Cannot check if an item is in missing ListConfig"
            )

        lst = self.__dict__["_content"]
        for x in lst:
            x = x._dereference_node()
            if x == item:
                return True
        return False

    def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
        """Replace the content with ``value``, restoring the previous state on error."""
        try:
            previous_content = self.__dict__["_content"]
            previous_metadata = self.__dict__["_metadata"]
            self._set_value_impl(value, flags)
        except Exception as e:
            self.__dict__["_content"] = previous_content
            self.__dict__["_metadata"] = previous_metadata
            raise e

    def _set_value_impl(
        self, value: Any, flags: Optional[Dict[str, bool]] = None
    ) -> None:
        """Replace the content with ``value``.

        ``value`` may be None, the missing literal, an interpolation string, or
        a list/tuple/ListConfig whose elements are appended one by one. Raises
        ValidationError for None on a non-optional node or an unsupported type.
        """
        from omegaconf import MISSING, flag_override

        if flags is None:
            flags = {}

        vk = get_value_kind(value, strict_interpolation_validation=True)
        if _is_none(value):
            if not self._is_optional():
                raise ValidationError(
                    "Non optional ListConfig cannot be constructed from None"
                )
            self.__dict__["_content"] = None
            self._metadata.object_type = None
        elif vk is ValueKind.MANDATORY_MISSING:
            self.__dict__["_content"] = MISSING
            self._metadata.object_type = None
        elif vk == ValueKind.INTERPOLATION:
            self.__dict__["_content"] = value
            self._metadata.object_type = None
        else:
            if not (is_primitive_list(value) or isinstance(value, ListConfig)):
                type_ = type(value)
                msg = f"Invalid value assigned: {type_.__name__} is not a ListConfig, list or tuple."
                raise ValidationError(msg)

            self.__dict__["_content"] = []
            if isinstance(value, ListConfig):
                self._metadata.flags = copy.deepcopy(flags)
                # Disable struct and readonly for the construction phase while
                # retaining other flags like allow_objects; flag_override
                # restores the real flags when the with-block exits.
                with flag_override(self, ["struct", "readonly"], False):
                    for item in value._iter_ex(resolve=False):
                        self.append(item)
            elif is_primitive_list(value):
                with flag_override(self, ["struct", "readonly"], False):
                    for item in value:
                        self.append(item)
            self._metadata.object_type = list

    @staticmethod
    def _list_eq(l1: Optional["ListConfig"], l2: Optional["ListConfig"]) -> bool:
        """Element-wise equality of two ListConfigs (None-valued ones compare equal)."""
        l1_none = l1.__dict__["_content"] is None
        l2_none = l2.__dict__["_content"] is None
        if l1_none and l2_none:
            return True
        if l1_none != l2_none:
            return False

        assert isinstance(l1, ListConfig)
        assert isinstance(l2, ListConfig)
        if len(l1) != len(l2):
            return False
        for i in range(len(l1)):
            if not BaseContainer._item_eq(l1, i, l2, i):
                return False

        return True
|