dict.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. import copy
  2. from collections import deque
  3. from collections.abc import Mapping, Sequence
  4. from typing import Dict, List, Optional, TypeVar, Union
  5. from ray.util.annotations import Deprecated
  6. T = TypeVar("T")
  7. @Deprecated
  8. def merge_dicts(d1: dict, d2: dict) -> dict:
  9. """
  10. Args:
  11. d1 (dict): Dict 1.
  12. d2 (dict): Dict 2.
  13. Returns:
  14. dict: A new dict that is d1 and d2 deep merged.
  15. """
  16. merged = copy.deepcopy(d1)
  17. deep_update(merged, d2, True, [])
  18. return merged
  19. @Deprecated
  20. def deep_update(
  21. original: dict,
  22. new_dict: dict,
  23. new_keys_allowed: bool = False,
  24. allow_new_subkey_list: Optional[List[str]] = None,
  25. override_all_if_type_changes: Optional[List[str]] = None,
  26. override_all_key_list: Optional[List[str]] = None,
  27. ) -> dict:
  28. """Updates original dict with values from new_dict recursively.
  29. If new key is introduced in new_dict, then if new_keys_allowed is not
  30. True, an error will be thrown. Further, for sub-dicts, if the key is
  31. in the allow_new_subkey_list, then new subkeys can be introduced.
  32. Args:
  33. original: Dictionary with default values.
  34. new_dict: Dictionary with values to be updated
  35. new_keys_allowed: Whether new keys are allowed.
  36. allow_new_subkey_list: List of keys that
  37. correspond to dict values where new subkeys can be introduced.
  38. This is only at the top level.
  39. override_all_if_type_changes: List of top level
  40. keys with value=dict, for which we always simply override the
  41. entire value (dict), iff the "type" key in that value dict changes.
  42. override_all_key_list: List of top level keys
  43. for which we override the entire value if the key is in the new_dict.
  44. """
  45. allow_new_subkey_list = allow_new_subkey_list or []
  46. override_all_if_type_changes = override_all_if_type_changes or []
  47. override_all_key_list = override_all_key_list or []
  48. for k, value in new_dict.items():
  49. if k not in original and not new_keys_allowed:
  50. raise Exception("Unknown config parameter `{}` ".format(k))
  51. # Both orginal value and new one are dicts.
  52. if (
  53. isinstance(original.get(k), dict)
  54. and isinstance(value, dict)
  55. and k not in override_all_key_list
  56. ):
  57. # Check old type vs old one. If different, override entire value.
  58. if (
  59. k in override_all_if_type_changes
  60. and "type" in value
  61. and "type" in original[k]
  62. and value["type"] != original[k]["type"]
  63. ):
  64. original[k] = value
  65. # Allowed key -> ok to add new subkeys.
  66. elif k in allow_new_subkey_list:
  67. deep_update(
  68. original[k],
  69. value,
  70. True,
  71. override_all_key_list=override_all_key_list,
  72. )
  73. # Non-allowed key.
  74. else:
  75. deep_update(
  76. original[k],
  77. value,
  78. new_keys_allowed,
  79. override_all_key_list=override_all_key_list,
  80. )
  81. # Original value not a dict OR new value not a dict:
  82. # Override entire value.
  83. else:
  84. original[k] = value
  85. return original
  86. @Deprecated
  87. def flatten_dict(
  88. dt: Dict,
  89. delimiter: str = "/",
  90. prevent_delimiter: bool = False,
  91. flatten_list: bool = False,
  92. ):
  93. """Flatten dict.
  94. Output and input are of the same dict type.
  95. Input dict remains the same after the operation.
  96. """
  97. def _raise_delimiter_exception():
  98. raise ValueError(
  99. f"Found delimiter `{delimiter}` in key when trying to flatten "
  100. f"array. Please avoid using the delimiter in your specification."
  101. )
  102. dt = copy.copy(dt)
  103. if prevent_delimiter and any(delimiter in key for key in dt):
  104. # Raise if delimiter is any of the keys
  105. _raise_delimiter_exception()
  106. while_check = (dict, list) if flatten_list else dict
  107. while any(isinstance(v, while_check) for v in dt.values()):
  108. remove = []
  109. add = {}
  110. for key, value in dt.items():
  111. if isinstance(value, dict):
  112. for subkey, v in value.items():
  113. if prevent_delimiter and delimiter in subkey:
  114. # Raise if delimiter is in any of the subkeys
  115. _raise_delimiter_exception()
  116. add[delimiter.join([key, str(subkey)])] = v
  117. remove.append(key)
  118. elif flatten_list and isinstance(value, list):
  119. for i, v in enumerate(value):
  120. if prevent_delimiter and delimiter in subkey:
  121. # Raise if delimiter is in any of the subkeys
  122. _raise_delimiter_exception()
  123. add[delimiter.join([key, str(i)])] = v
  124. remove.append(key)
  125. dt.update(add)
  126. for k in remove:
  127. del dt[k]
  128. return dt
  129. @Deprecated
  130. def unflatten_dict(dt: Dict[str, T], delimiter: str = "/") -> Dict[str, T]:
  131. """Unflatten dict. Does not support unflattening lists."""
  132. dict_type = type(dt)
  133. out = dict_type()
  134. for key, val in dt.items():
  135. path = key.split(delimiter)
  136. item = out
  137. for k in path[:-1]:
  138. item = item.setdefault(k, dict_type())
  139. if not isinstance(item, dict_type):
  140. raise TypeError(
  141. f"Cannot unflatten dict due the key '{key}' "
  142. f"having a parent key '{k}', which value is not "
  143. f"of type {dict_type} (got {type(item)}). "
  144. "Change the key names to resolve the conflict."
  145. )
  146. item[path[-1]] = val
  147. return out
  148. @Deprecated
  149. def unflatten_list_dict(dt: Dict[str, T], delimiter: str = "/") -> Dict[str, T]:
  150. """Unflatten nested dict and list.
  151. This function now has some limitations:
  152. (1) The keys of dt must be str.
  153. (2) If unflattened dt (the result) contains list, the index order must be
  154. ascending when accessing dt. Otherwise, this function will throw
  155. AssertionError.
  156. (3) The unflattened dt (the result) shouldn't contain dict with number
  157. keys.
  158. Be careful to use this function. If you want to improve this function,
  159. please also improve the unit test. See #14487 for more details.
  160. Args:
  161. dt: Flattened dictionary that is originally nested by multiple
  162. list and dict.
  163. delimiter: Delimiter of keys.
  164. Example:
  165. >>> dt = {"aaa/0/bb": 12, "aaa/1/cc": 56, "aaa/1/dd": 92}
  166. >>> unflatten_list_dict(dt)
  167. {'aaa': [{'bb': 12}, {'cc': 56, 'dd': 92}]}
  168. """
  169. out_type = list if list(dt)[0].split(delimiter, 1)[0].isdigit() else type(dt)
  170. out = out_type()
  171. for key, val in dt.items():
  172. path = key.split(delimiter)
  173. item = out
  174. for i, k in enumerate(path[:-1]):
  175. next_type = list if path[i + 1].isdigit() else dict
  176. if isinstance(item, dict):
  177. item = item.setdefault(k, next_type())
  178. elif isinstance(item, list):
  179. if int(k) >= len(item):
  180. item.append(next_type())
  181. assert int(k) == len(item) - 1
  182. item = item[int(k)]
  183. if isinstance(item, dict):
  184. item[path[-1]] = val
  185. elif isinstance(item, list):
  186. item.append(val)
  187. assert int(path[-1]) == len(item) - 1
  188. return out
  189. @Deprecated
  190. def unflattened_lookup(
  191. flat_key: str, lookup: Union[Mapping, Sequence], delimiter: str = "/", **kwargs
  192. ) -> Union[Mapping, Sequence]:
  193. """
  194. Unflatten `flat_key` and iteratively look up in `lookup`. E.g.
  195. `flat_key="a/0/b"` will try to return `lookup["a"][0]["b"]`.
  196. """
  197. if flat_key in lookup:
  198. return lookup[flat_key]
  199. keys = deque(flat_key.split(delimiter))
  200. base = lookup
  201. while keys:
  202. key = keys.popleft()
  203. try:
  204. if isinstance(base, Mapping):
  205. base = base[key]
  206. elif isinstance(base, Sequence):
  207. base = base[int(key)]
  208. else:
  209. raise KeyError()
  210. except KeyError as e:
  211. if "default" in kwargs:
  212. return kwargs["default"]
  213. raise e
  214. return base