| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247 |
- import copy
- from collections import deque
- from collections.abc import Mapping, Sequence
- from typing import Dict, List, Optional, TypeVar, Union
- from ray.util.annotations import Deprecated
- T = TypeVar("T")
- @Deprecated
def merge_dicts(d1: dict, d2: dict) -> dict:
    """Return a new dict holding the deep merge of `d1` and `d2`.

    `d1` is deep-copied first, so it is never modified. `d2` is then merged
    into that copy through `deep_update` with new keys allowed at every
    level; on conflicts the value from `d2` wins. Values taken from `d2`
    are inserted by reference, not copied.

    Args:
        d1 (dict): Dict 1, providing the base values.
        d2 (dict): Dict 2, providing the overriding values.

    Returns:
        dict: A new dict that is d1 and d2 deep merged.
    """
    # `deep_update` hands back the very object it updated, so the merged copy
    # can be returned straight from the call.
    return deep_update(
        copy.deepcopy(d1),
        d2,
        new_keys_allowed=True,
        allow_new_subkey_list=[],
    )
- @Deprecated
def deep_update(
    original: dict,
    new_dict: dict,
    new_keys_allowed: bool = False,
    allow_new_subkey_list: Optional[List[str]] = None,
    override_all_if_type_changes: Optional[List[str]] = None,
    override_all_key_list: Optional[List[str]] = None,
) -> dict:
    """Recursively merge the values of `new_dict` into `original`, in place.

    A key of `new_dict` that is missing from `original` is an error unless
    `new_keys_allowed` is True. For sub-dicts whose key is listed in
    `allow_new_subkey_list`, new subkeys may always be introduced.

    Args:
        original: Dictionary with default values. It is mutated.
        new_dict: Dictionary with values to be updated.
        new_keys_allowed: Whether new keys are allowed.
        allow_new_subkey_list: List of keys that correspond to dict values
            where new subkeys can be introduced. This is only at the top
            level (it is not forwarded to recursive calls).
        override_all_if_type_changes: List of top level keys with
            value=dict, for which we always simply override the entire
            value (dict), iff the "type" key in that value dict changes.
        override_all_key_list: List of keys for which we override the entire
            value if the key is in the new_dict. Unlike the two lists above,
            this one is forwarded to recursive calls.

    Returns:
        The `original` dict itself, after the update.

    Raises:
        Exception: If `new_dict` contains a key that is not in `original`
            at a level where new keys are not allowed.
    """
    subkeys_ok_for = allow_new_subkey_list or []
    type_sensitive = override_all_if_type_changes or []
    replace_whole = override_all_key_list or []

    for key, incoming in new_dict.items():
        if key not in original and not new_keys_allowed:
            raise Exception("Unknown config parameter `{}` ".format(key))

        existing = original.get(key)
        both_dicts = isinstance(existing, dict) and isinstance(incoming, dict)

        # Original value not a dict, new value not a dict, or the key is
        # marked for wholesale replacement: override the entire value.
        if not both_dicts or key in replace_whole:
            original[key] = incoming
            continue

        # Both the original value and the new one are dicts. Compare the new
        # "type" against the old one; if it changed, override the entire
        # value instead of merging.
        type_changed = (
            key in type_sensitive
            and "type" in incoming
            and "type" in existing
            and incoming["type"] != existing["type"]
        )
        if type_changed:
            original[key] = incoming
            continue

        # Merge the sub-dict. Keys on the allow-list may gain new subkeys
        # regardless of `new_keys_allowed`; all others inherit the flag.
        deep_update(
            existing,
            incoming,
            True if key in subkeys_ok_for else new_keys_allowed,
            override_all_key_list=replace_whole,
        )

    return original
- @Deprecated
def flatten_dict(
    dt: Dict,
    delimiter: str = "/",
    prevent_delimiter: bool = False,
    flatten_list: bool = False,
):
    """Flatten a nested dict into one level with delimiter-joined keys.

    Output and input are of the same dict type (the input is shallow-copied).
    Input dict remains the same after the operation.

    Args:
        dt: The (possibly nested) dict to flatten.
        delimiter: String placed between the parent key and the child key.
        prevent_delimiter: If True, raise when the delimiter already occurs
            in any key, since the flattened keys would then be ambiguous.
            The check uses `in`, so keys are expected to be strings.
        flatten_list: If True, lists are flattened as well, using the
            element index as the key component.

    Returns:
        A flat dict of the same type as `dt`.

    Raises:
        ValueError: If `prevent_delimiter` is True and the delimiter is found
            in a key (or in a stringified list index).
    """

    def _raise_delimiter_exception():
        raise ValueError(
            f"Found delimiter `{delimiter}` in key when trying to flatten "
            f"array. Please avoid using the delimiter in your specification."
        )

    # Shallow copy: only the top level is restructured, nested containers are
    # read but never mutated.
    dt = copy.copy(dt)
    if prevent_delimiter and any(delimiter in key for key in dt):
        # Raise if delimiter is any of the keys
        _raise_delimiter_exception()

    # Each pass lifts every nested container up by one level; repeat until no
    # flattenable container is left among the values.
    while_check = (dict, list) if flatten_list else dict
    while any(isinstance(v, while_check) for v in dt.values()):
        remove = []
        add = {}
        for key, value in dt.items():
            if isinstance(value, dict):
                for subkey, v in value.items():
                    if prevent_delimiter and delimiter in subkey:
                        # Raise if delimiter is in any of the subkeys
                        _raise_delimiter_exception()
                    add[delimiter.join([key, str(subkey)])] = v
                remove.append(key)
            elif flatten_list and isinstance(value, list):
                for i, v in enumerate(value):
                    index = str(i)
                    # The list index plays the role of the subkey here. The
                    # check must look at the index, not at a `subkey` variable
                    # that is unbound (or stale) in this branch.
                    if prevent_delimiter and delimiter in index:
                        _raise_delimiter_exception()
                    add[delimiter.join([key, index])] = v
                remove.append(key)

        # Mutate only after the iteration over `dt` has finished.
        dt.update(add)
        for k in remove:
            del dt[k]
    return dt
- @Deprecated
def unflatten_dict(dt: Dict[str, T], delimiter: str = "/") -> Dict[str, T]:
    """Rebuild a nested dict from delimiter-joined keys.

    Does not support unflattening lists. The result, and every nested level
    created along the way, has the same type as `dt`.

    Args:
        dt: Flat dict whose keys are paths joined by `delimiter`.
        delimiter: String separating the components of each key.

    Returns:
        The nested dict.

    Raises:
        TypeError: If a parent component of some key already maps to a value
            that is not of the dict type, so it cannot be descended into.
    """
    dict_type = type(dt)
    result = dict_type()
    for key, leaf in dt.items():
        # Everything but the final component names a parent level.
        *parents, last = key.split(delimiter)
        item = result
        for k in parents:
            # Descend, creating the intermediate level on first visit.
            item = item.setdefault(k, dict_type())
            if not isinstance(item, dict_type):
                raise TypeError(
                    f"Cannot unflatten dict due the key '{key}' "
                    f"having a parent key '{k}', which value is not "
                    f"of type {dict_type} (got {type(item)}). "
                    "Change the key names to resolve the conflict."
                )
        item[last] = leaf
    return result
- @Deprecated
def unflatten_list_dict(dt: Dict[str, T], delimiter: str = "/") -> Dict[str, T]:
    """Unflatten nested dict and list.

    A key component made only of digits is treated as a list index; any other
    component is treated as a dict key.

    This function now has some limitations:
    (1) The keys of dt must be str.
    (2) If unflattened dt (the result) contains list, the index order must be
        ascending when accessing dt. Otherwise, this function will throw
        AssertionError.
    (3) The unflattened dt (the result) shouldn't contain dict with number
        keys.

    Be careful to use this function. If you want to improve this function,
    please also improve the unit test. See #14487 for more details.

    Args:
        dt: Flattened dictionary that is originally nested by multiple
            list and dict.
        delimiter: Delimiter of keys.

    Returns:
        The nested structure. It is a list if the first key of `dt` starts
        with a numeric component, otherwise an instance of `type(dt)`. An
        empty `dt` yields an empty instance of `type(dt)`.

    Raises:
        AssertionError: If list indices are not encountered in ascending,
            gap-free order.

    Example:
        >>> dt = {"aaa/0/bb": 12, "aaa/1/cc": 56, "aaa/1/dd": 92}
        >>> unflatten_list_dict(dt)
        {'aaa': [{'bb': 12}, {'cc': 56, 'dd': 92}]}
    """
    # Nothing to inspect: without this guard the first-key probe below would
    # fail on empty input.
    if not dt:
        return type(dt)()

    # The first component of the first key decides the top-level container.
    first_key = next(iter(dt))
    out_type = list if first_key.split(delimiter, 1)[0].isdigit() else type(dt)
    out = out_type()
    for key, val in dt.items():
        path = key.split(delimiter)
        item = out
        for i, k in enumerate(path[:-1]):
            # The *next* component tells which container this level holds.
            next_type = list if path[i + 1].isdigit() else dict
            if isinstance(item, dict):
                item = item.setdefault(k, next_type())
            elif isinstance(item, list):
                index = int(k)
                if index >= len(item):
                    item.append(next_type())
                    # Explicit raise (not `assert`) so the check also runs
                    # when Python is started with -O.
                    if index != len(item) - 1:
                        raise AssertionError(
                            f"List index {index} in key '{key}' is not the "
                            "next index in ascending order."
                        )
                item = item[index]
        if isinstance(item, dict):
            item[path[-1]] = val
        elif isinstance(item, list):
            item.append(val)
            if int(path[-1]) != len(item) - 1:
                raise AssertionError(
                    f"List index {path[-1]} in key '{key}' is not the "
                    "next index in ascending order."
                )
    return out
- @Deprecated
def unflattened_lookup(
    flat_key: str, lookup: Union[Mapping, Sequence], delimiter: str = "/", **kwargs
) -> Union[Mapping, Sequence]:
    """Unflatten `flat_key` and iteratively look it up in `lookup`.

    E.g. `flat_key="a/0/b"` will try to return `lookup["a"][0]["b"]`.
    If `lookup` directly contains `flat_key` as a key, that entry is returned
    without splitting.

    Args:
        flat_key: Path of keys / indices joined by `delimiter`.
        lookup: Nested structure of mappings and sequences to search.
        delimiter: String separating the components of `flat_key`.
        **kwargs: If `default` is given, it is returned when a step of the
            lookup fails, instead of raising.

    Returns:
        The value found at the path, or `kwargs["default"]` if the lookup
        fails and a default was supplied.

    Raises:
        KeyError: If a mapping lacks a component, or a non-container value is
            reached before the path is exhausted (and no default was given).
        IndexError: If a sequence index is out of range (and no default was
            given).
        ValueError: If a sequence is indexed with a non-integer component
            (and no default was given).
    """
    # Direct hit on the un-split key. Skipped for sequences, where `in` tests
    # element membership, so a match would then fail on `lookup[flat_key]`.
    if not isinstance(lookup, Sequence) and flat_key in lookup:
        return lookup[flat_key]

    keys = deque(flat_key.split(delimiter))
    base = lookup
    while keys:
        key = keys.popleft()
        try:
            if isinstance(base, Mapping):
                base = base[key]
            elif isinstance(base, Sequence):
                base = base[int(key)]
            else:
                # Reached a leaf while path components remain.
                raise KeyError(key)
        except (KeyError, IndexError, ValueError):
            # A failed step in a sequence (bad or out-of-range index) is as
            # much a miss as a missing mapping key, so `default` covers both.
            if "default" in kwargs:
                return kwargs["default"]
            raise
    return base
|