listconfig.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679
  1. import copy
  2. import itertools
  3. from typing import (
  4. Any,
  5. Callable,
  6. Dict,
  7. Iterable,
  8. Iterator,
  9. List,
  10. MutableSequence,
  11. Optional,
  12. Tuple,
  13. Type,
  14. Union,
  15. )
  16. from ._utils import (
  17. ValueKind,
  18. _is_missing_literal,
  19. _is_none,
  20. _resolve_optional,
  21. format_and_raise,
  22. get_value_kind,
  23. is_int,
  24. is_primitive_list,
  25. is_structured_config,
  26. type_str,
  27. )
  28. from .base import Box, ContainerMetadata, Node
  29. from .basecontainer import BaseContainer
  30. from .errors import (
  31. ConfigAttributeError,
  32. ConfigTypeError,
  33. ConfigValueError,
  34. KeyValidationError,
  35. MissingMandatoryValue,
  36. ReadonlyConfigError,
  37. ValidationError,
  38. )
  39. class ListConfig(BaseContainer, MutableSequence[Any]):
  40. _content: Union[List[Node], None, str]
  41. def __init__(
  42. self,
  43. content: Union[List[Any], Tuple[Any, ...], "ListConfig", str, None],
  44. key: Any = None,
  45. parent: Optional[Box] = None,
  46. element_type: Union[Type[Any], Any] = Any,
  47. is_optional: bool = True,
  48. ref_type: Union[Type[Any], Any] = Any,
  49. flags: Optional[Dict[str, bool]] = None,
  50. ) -> None:
  51. try:
  52. if isinstance(content, ListConfig):
  53. if flags is None:
  54. flags = content._metadata.flags
  55. super().__init__(
  56. parent=parent,
  57. metadata=ContainerMetadata(
  58. ref_type=ref_type,
  59. object_type=list,
  60. key=key,
  61. optional=is_optional,
  62. element_type=element_type,
  63. key_type=int,
  64. flags=flags,
  65. ),
  66. )
  67. if isinstance(content, ListConfig):
  68. metadata = copy.deepcopy(content._metadata)
  69. metadata.key = key
  70. metadata.ref_type = ref_type
  71. metadata.optional = is_optional
  72. metadata.element_type = element_type
  73. self.__dict__["_metadata"] = metadata
  74. self._set_value(value=content, flags=flags)
  75. except Exception as ex:
  76. format_and_raise(node=None, key=key, value=None, cause=ex, msg=str(ex))
  77. def _validate_get(self, key: Any, value: Any = None) -> None:
  78. if not isinstance(key, (int, slice)):
  79. raise KeyValidationError(
  80. "ListConfig indices must be integers or slices, not $KEY_TYPE"
  81. )
  82. def _validate_set(self, key: Any, value: Any) -> None:
  83. from omegaconf import OmegaConf
  84. self._validate_get(key, value)
  85. if self._get_flag("readonly"):
  86. raise ReadonlyConfigError("ListConfig is read-only")
  87. if 0 <= key < self.__len__():
  88. target = self._get_node(key)
  89. if target is not None:
  90. assert isinstance(target, Node)
  91. if value is None and not target._is_optional():
  92. raise ValidationError(
  93. "$FULL_KEY is not optional and cannot be assigned None"
  94. )
  95. vk = get_value_kind(value)
  96. if vk == ValueKind.MANDATORY_MISSING:
  97. return
  98. else:
  99. is_optional, target_type = _resolve_optional(self._metadata.element_type)
  100. value_type = OmegaConf.get_type(value)
  101. if (value_type is None and not is_optional) or (
  102. is_structured_config(target_type)
  103. and value_type is not None
  104. and not issubclass(value_type, target_type)
  105. ):
  106. msg = (
  107. f"Invalid type assigned: {type_str(value_type)} is not a "
  108. f"subclass of {type_str(target_type)}. value: {value}"
  109. )
  110. raise ValidationError(msg)
  111. def __deepcopy__(self, memo: Dict[int, Any]) -> "ListConfig":
  112. res = ListConfig(None)
  113. res.__dict__["_metadata"] = copy.deepcopy(self.__dict__["_metadata"], memo=memo)
  114. res.__dict__["_flags_cache"] = copy.deepcopy(
  115. self.__dict__["_flags_cache"], memo=memo
  116. )
  117. src_content = self.__dict__["_content"]
  118. if isinstance(src_content, list):
  119. content_copy: List[Optional[Node]] = []
  120. for v in src_content:
  121. old_parent = v.__dict__["_parent"]
  122. try:
  123. v.__dict__["_parent"] = None
  124. vc = copy.deepcopy(v, memo=memo)
  125. vc.__dict__["_parent"] = res
  126. content_copy.append(vc)
  127. finally:
  128. v.__dict__["_parent"] = old_parent
  129. else:
  130. # None and strings can be assigned as is
  131. content_copy = src_content
  132. res.__dict__["_content"] = content_copy
  133. res.__dict__["_parent"] = self.__dict__["_parent"]
  134. return res
  135. def copy(self) -> "ListConfig":
  136. return copy.copy(self)
  137. # hide content while inspecting in debugger
  138. def __dir__(self) -> Iterable[str]:
  139. if self._is_missing() or self._is_none():
  140. return []
  141. return [str(x) for x in range(0, len(self))]
  142. def __setattr__(self, key: str, value: Any) -> None:
  143. self._format_and_raise(
  144. key=key,
  145. value=value,
  146. cause=ConfigAttributeError("ListConfig does not support attribute access"),
  147. )
  148. assert False
  149. def __getattr__(self, key: str) -> Any:
  150. # PyCharm is sometimes inspecting __members__, be sure to tell it we don't have that.
  151. if key == "__members__":
  152. raise AttributeError()
  153. if key == "__name__":
  154. raise AttributeError()
  155. if is_int(key):
  156. return self.__getitem__(int(key))
  157. else:
  158. self._format_and_raise(
  159. key=key,
  160. value=None,
  161. cause=ConfigAttributeError(
  162. "ListConfig does not support attribute access"
  163. ),
  164. )
  165. def __getitem__(self, index: Union[int, slice]) -> Any:
  166. try:
  167. if self._is_missing():
  168. raise MissingMandatoryValue("ListConfig is missing")
  169. self._validate_get(index, None)
  170. if self._is_none():
  171. raise TypeError(
  172. "ListConfig object representing None is not subscriptable"
  173. )
  174. assert isinstance(self.__dict__["_content"], list)
  175. if isinstance(index, slice):
  176. result = []
  177. start, stop, step = self._correct_index_params(index)
  178. for slice_idx in itertools.islice(
  179. range(0, len(self)), start, stop, step
  180. ):
  181. val = self._resolve_with_default(
  182. key=slice_idx, value=self.__dict__["_content"][slice_idx]
  183. )
  184. result.append(val)
  185. if index.step and index.step < 0:
  186. result.reverse()
  187. return result
  188. else:
  189. return self._resolve_with_default(
  190. key=index, value=self.__dict__["_content"][index]
  191. )
  192. except Exception as e:
  193. self._format_and_raise(key=index, value=None, cause=e)
  194. def _correct_index_params(self, index: slice) -> Tuple[int, int, int]:
  195. start = index.start
  196. stop = index.stop
  197. step = index.step
  198. if index.start and index.start < 0:
  199. start = self.__len__() + index.start
  200. if index.stop and index.stop < 0:
  201. stop = self.__len__() + index.stop
  202. if index.step and index.step < 0:
  203. step = abs(step)
  204. if start and stop:
  205. if start > stop:
  206. start, stop = stop + 1, start + 1
  207. else:
  208. start = stop = 0
  209. elif not start and stop:
  210. start = list(range(self.__len__() - 1, stop, -step))[0]
  211. stop = None
  212. elif start and not stop:
  213. stop = start + 1
  214. start = (stop - 1) % step
  215. else:
  216. start = (self.__len__() - 1) % step
  217. return start, stop, step
  218. def _set_at_index(self, index: Union[int, slice], value: Any) -> None:
  219. self._set_item_impl(index, value)
  220. def __setitem__(self, index: Union[int, slice], value: Any) -> None:
  221. try:
  222. if isinstance(index, slice):
  223. _ = iter(value) # check iterable
  224. self_indices = index.indices(len(self))
  225. indexes = range(*self_indices)
  226. # Ensure lengths match for extended slice assignment
  227. if index.step not in (None, 1):
  228. if len(indexes) != len(value):
  229. raise ValueError(
  230. f"attempt to assign sequence of size {len(value)}"
  231. f" to extended slice of size {len(indexes)}"
  232. )
  233. # Initialize insertion offsets for empty slices
  234. if len(indexes) == 0:
  235. curr_index = self_indices[0] - 1
  236. val_i = -1
  237. work_copy = self.copy() # For atomicity manipulate a copy
  238. # Delete and optionally replace non empty slices
  239. only_removed = 0
  240. for val_i, i in enumerate(indexes):
  241. curr_index = i - only_removed
  242. del work_copy[curr_index]
  243. if val_i < len(value):
  244. work_copy.insert(curr_index, value[val_i])
  245. else:
  246. only_removed += 1
  247. # Insert any remaining input items
  248. for val_i in range(val_i + 1, len(value)):
  249. curr_index += 1
  250. work_copy.insert(curr_index, value[val_i])
  251. # Reinitialize self with work_copy
  252. self.clear()
  253. self.extend(work_copy)
  254. else:
  255. self._set_at_index(index, value)
  256. except Exception as e:
  257. self._format_and_raise(key=index, value=value, cause=e)
  258. def append(self, item: Any) -> None:
  259. content = self.__dict__["_content"]
  260. index = len(content)
  261. content.append(None)
  262. try:
  263. self._set_item_impl(index, item)
  264. except Exception as e:
  265. del content[index]
  266. self._format_and_raise(key=index, value=item, cause=e)
  267. assert False
  268. def _update_keys(self) -> None:
  269. for i in range(len(self)):
  270. node = self._get_node(i)
  271. if node is not None:
  272. assert isinstance(node, Node)
  273. node._metadata.key = i
  274. def insert(self, index: int, item: Any) -> None:
  275. from omegaconf.omegaconf import _maybe_wrap
  276. try:
  277. if self._get_flag("readonly"):
  278. raise ReadonlyConfigError("Cannot insert into a read-only ListConfig")
  279. if self._is_none():
  280. raise TypeError(
  281. "Cannot insert into ListConfig object representing None"
  282. )
  283. if self._is_missing():
  284. raise MissingMandatoryValue("Cannot insert into missing ListConfig")
  285. try:
  286. assert isinstance(self.__dict__["_content"], list)
  287. # insert place holder
  288. self.__dict__["_content"].insert(index, None)
  289. is_optional, ref_type = _resolve_optional(self._metadata.element_type)
  290. node = _maybe_wrap(
  291. ref_type=ref_type,
  292. key=index,
  293. value=item,
  294. is_optional=is_optional,
  295. parent=self,
  296. )
  297. self._validate_set(key=index, value=node)
  298. self._set_at_index(index, node)
  299. self._update_keys()
  300. except Exception:
  301. del self.__dict__["_content"][index]
  302. self._update_keys()
  303. raise
  304. except Exception as e:
  305. self._format_and_raise(key=index, value=item, cause=e)
  306. assert False
  307. def extend(self, lst: Iterable[Any]) -> None:
  308. assert isinstance(lst, (tuple, list, ListConfig))
  309. for x in lst:
  310. self.append(x)
  311. def remove(self, x: Any) -> None:
  312. del self[self.index(x)]
  313. def __delitem__(self, key: Union[int, slice]) -> None:
  314. if self._get_flag("readonly"):
  315. self._format_and_raise(
  316. key=key,
  317. value=None,
  318. cause=ReadonlyConfigError(
  319. "Cannot delete item from read-only ListConfig"
  320. ),
  321. )
  322. del self.__dict__["_content"][key]
  323. self._update_keys()
  324. def clear(self) -> None:
  325. del self[:]
  326. def index(
  327. self, x: Any, start: Optional[int] = None, end: Optional[int] = None
  328. ) -> int:
  329. if start is None:
  330. start = 0
  331. if end is None:
  332. end = len(self)
  333. assert start >= 0
  334. assert end <= len(self)
  335. found_idx = -1
  336. for idx in range(start, end):
  337. item = self[idx]
  338. if x == item:
  339. found_idx = idx
  340. break
  341. if found_idx != -1:
  342. return found_idx
  343. else:
  344. self._format_and_raise(
  345. key=None,
  346. value=None,
  347. cause=ConfigValueError("Item not found in ListConfig"),
  348. )
  349. assert False
  350. def count(self, x: Any) -> int:
  351. c = 0
  352. for item in self:
  353. if item == x:
  354. c = c + 1
  355. return c
  356. def _get_node(
  357. self,
  358. key: Union[int, slice],
  359. validate_access: bool = True,
  360. validate_key: bool = True,
  361. throw_on_missing_value: bool = False,
  362. throw_on_missing_key: bool = False,
  363. ) -> Union[Optional[Node], List[Optional[Node]]]:
  364. try:
  365. if self._is_none():
  366. raise TypeError(
  367. "Cannot get_node from a ListConfig object representing None"
  368. )
  369. if self._is_missing():
  370. raise MissingMandatoryValue("Cannot get_node from a missing ListConfig")
  371. assert isinstance(self.__dict__["_content"], list)
  372. if validate_access:
  373. self._validate_get(key)
  374. value = self.__dict__["_content"][key]
  375. if value is not None:
  376. if isinstance(key, slice):
  377. assert isinstance(value, list)
  378. for v in value:
  379. if throw_on_missing_value and v._is_missing():
  380. raise MissingMandatoryValue("Missing mandatory value")
  381. else:
  382. assert isinstance(value, Node)
  383. if throw_on_missing_value and value._is_missing():
  384. raise MissingMandatoryValue("Missing mandatory value: $KEY")
  385. return value
  386. except (IndexError, TypeError, MissingMandatoryValue, KeyValidationError) as e:
  387. if isinstance(e, MissingMandatoryValue) and throw_on_missing_value:
  388. raise
  389. if validate_access:
  390. self._format_and_raise(key=key, value=None, cause=e)
  391. assert False
  392. else:
  393. return None
  394. def get(self, index: int, default_value: Any = None) -> Any:
  395. try:
  396. if self._is_none():
  397. raise TypeError("Cannot get from a ListConfig object representing None")
  398. if self._is_missing():
  399. raise MissingMandatoryValue("Cannot get from a missing ListConfig")
  400. self._validate_get(index, None)
  401. assert isinstance(self.__dict__["_content"], list)
  402. return self._resolve_with_default(
  403. key=index,
  404. value=self.__dict__["_content"][index],
  405. default_value=default_value,
  406. )
  407. except Exception as e:
  408. self._format_and_raise(key=index, value=None, cause=e)
  409. assert False
  410. def pop(self, index: int = -1) -> Any:
  411. try:
  412. if self._get_flag("readonly"):
  413. raise ReadonlyConfigError("Cannot pop from read-only ListConfig")
  414. if self._is_none():
  415. raise TypeError("Cannot pop from a ListConfig object representing None")
  416. if self._is_missing():
  417. raise MissingMandatoryValue("Cannot pop from a missing ListConfig")
  418. assert isinstance(self.__dict__["_content"], list)
  419. node = self._get_child(index)
  420. assert isinstance(node, Node)
  421. ret = self._resolve_with_default(key=index, value=node, default_value=None)
  422. del self.__dict__["_content"][index]
  423. self._update_keys()
  424. return ret
  425. except KeyValidationError as e:
  426. self._format_and_raise(
  427. key=index, value=None, cause=e, type_override=ConfigTypeError
  428. )
  429. assert False
  430. except Exception as e:
  431. self._format_and_raise(key=index, value=None, cause=e)
  432. assert False
  433. def sort(
  434. self, key: Optional[Callable[[Any], Any]] = None, reverse: bool = False
  435. ) -> None:
  436. try:
  437. if self._get_flag("readonly"):
  438. raise ReadonlyConfigError("Cannot sort a read-only ListConfig")
  439. if self._is_none():
  440. raise TypeError("Cannot sort a ListConfig object representing None")
  441. if self._is_missing():
  442. raise MissingMandatoryValue("Cannot sort a missing ListConfig")
  443. if key is None:
  444. def key1(x: Any) -> Any:
  445. return x._value()
  446. else:
  447. def key1(x: Any) -> Any:
  448. return key(x._value()) # type: ignore
  449. assert isinstance(self.__dict__["_content"], list)
  450. self.__dict__["_content"].sort(key=key1, reverse=reverse)
  451. except Exception as e:
  452. self._format_and_raise(key=None, value=None, cause=e)
  453. assert False
  454. def __eq__(self, other: Any) -> bool:
  455. if isinstance(other, (list, tuple)) or other is None:
  456. other = ListConfig(other, flags={"allow_objects": True})
  457. return ListConfig._list_eq(self, other)
  458. if other is None or isinstance(other, ListConfig):
  459. return ListConfig._list_eq(self, other)
  460. if self._is_missing():
  461. return _is_missing_literal(other)
  462. return NotImplemented
  463. def __ne__(self, other: Any) -> bool:
  464. x = self.__eq__(other)
  465. if x is not NotImplemented:
  466. return not x
  467. return NotImplemented
  468. def __hash__(self) -> int:
  469. return hash(str(self))
  470. def __iter__(self) -> Iterator[Any]:
  471. return self._iter_ex(resolve=True)
  472. class ListIterator(Iterator[Any]):
  473. def __init__(self, lst: Any, resolve: bool) -> None:
  474. self.resolve = resolve
  475. self.iterator = iter(lst.__dict__["_content"])
  476. self.index = 0
  477. from .nodes import ValueNode
  478. self.ValueNode = ValueNode
  479. def __next__(self) -> Any:
  480. x = next(self.iterator)
  481. if self.resolve:
  482. x = x._dereference_node()
  483. if x._is_missing():
  484. raise MissingMandatoryValue(f"Missing value at index {self.index}")
  485. self.index = self.index + 1
  486. if isinstance(x, self.ValueNode):
  487. return x._value()
  488. else:
  489. # Must be omegaconf.Container. not checking for perf reasons.
  490. if x._is_none():
  491. return None
  492. return x
  493. def __repr__(self) -> str: # pragma: no cover
  494. return f"ListConfig.ListIterator(resolve={self.resolve})"
  495. def _iter_ex(self, resolve: bool) -> Iterator[Any]:
  496. try:
  497. if self._is_none():
  498. raise TypeError("Cannot iterate a ListConfig object representing None")
  499. if self._is_missing():
  500. raise MissingMandatoryValue("Cannot iterate a missing ListConfig")
  501. return ListConfig.ListIterator(self, resolve)
  502. except (TypeError, MissingMandatoryValue) as e:
  503. self._format_and_raise(key=None, value=None, cause=e)
  504. assert False
  505. def __add__(self, other: Union[List[Any], "ListConfig"]) -> "ListConfig":
  506. # res is sharing this list's parent to allow interpolation to work as expected
  507. res = ListConfig(parent=self._get_parent(), content=[])
  508. res.extend(self)
  509. res.extend(other)
  510. return res
  511. def __radd__(self, other: Union[List[Any], "ListConfig"]) -> "ListConfig":
  512. # res is sharing this list's parent to allow interpolation to work as expected
  513. res = ListConfig(parent=self._get_parent(), content=[])
  514. res.extend(other)
  515. res.extend(self)
  516. return res
  517. def __iadd__(self, other: Iterable[Any]) -> "ListConfig":
  518. self.extend(other)
  519. return self
  520. def __contains__(self, item: Any) -> bool:
  521. if self._is_none():
  522. raise TypeError(
  523. "Cannot check if an item is in a ListConfig object representing None"
  524. )
  525. if self._is_missing():
  526. raise MissingMandatoryValue(
  527. "Cannot check if an item is in missing ListConfig"
  528. )
  529. lst = self.__dict__["_content"]
  530. for x in lst:
  531. x = x._dereference_node()
  532. if x == item:
  533. return True
  534. return False
  535. def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
  536. try:
  537. previous_content = self.__dict__["_content"]
  538. previous_metadata = self.__dict__["_metadata"]
  539. self._set_value_impl(value, flags)
  540. except Exception as e:
  541. self.__dict__["_content"] = previous_content
  542. self.__dict__["_metadata"] = previous_metadata
  543. raise e
  544. def _set_value_impl(
  545. self, value: Any, flags: Optional[Dict[str, bool]] = None
  546. ) -> None:
  547. from omegaconf import MISSING, flag_override
  548. if flags is None:
  549. flags = {}
  550. vk = get_value_kind(value, strict_interpolation_validation=True)
  551. if _is_none(value):
  552. if not self._is_optional():
  553. raise ValidationError(
  554. "Non optional ListConfig cannot be constructed from None"
  555. )
  556. self.__dict__["_content"] = None
  557. self._metadata.object_type = None
  558. elif vk is ValueKind.MANDATORY_MISSING:
  559. self.__dict__["_content"] = MISSING
  560. self._metadata.object_type = None
  561. elif vk == ValueKind.INTERPOLATION:
  562. self.__dict__["_content"] = value
  563. self._metadata.object_type = None
  564. else:
  565. if not (is_primitive_list(value) or isinstance(value, ListConfig)):
  566. type_ = type(value)
  567. msg = f"Invalid value assigned: {type_.__name__} is not a ListConfig, list or tuple."
  568. raise ValidationError(msg)
  569. self.__dict__["_content"] = []
  570. if isinstance(value, ListConfig):
  571. self._metadata.flags = copy.deepcopy(flags)
  572. # disable struct and readonly for the construction phase
  573. # retaining other flags like allow_objects. The real flags are restored at the end of this function
  574. with flag_override(self, ["struct", "readonly"], False):
  575. for item in value._iter_ex(resolve=False):
  576. self.append(item)
  577. elif is_primitive_list(value):
  578. with flag_override(self, ["struct", "readonly"], False):
  579. for item in value:
  580. self.append(item)
  581. self._metadata.object_type = list
  582. @staticmethod
  583. def _list_eq(l1: Optional["ListConfig"], l2: Optional["ListConfig"]) -> bool:
  584. l1_none = l1.__dict__["_content"] is None
  585. l2_none = l2.__dict__["_content"] is None
  586. if l1_none and l2_none:
  587. return True
  588. if l1_none != l2_none:
  589. return False
  590. assert isinstance(l1, ListConfig)
  591. assert isinstance(l2, ListConfig)
  592. if len(l1) != len(l2):
  593. return False
  594. for i in range(len(l1)):
  595. if not BaseContainer._item_eq(l1, i, l2, i):
  596. return False
  597. return True