collections.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
from typing import Dict, TypeVar
  2. K = TypeVar("K")
  3. def collapse_transitive_map(d: Dict[K, K]) -> Dict[K, K]:
  4. """Collapse transitive mappings in a dictionary. Given a mapping like
  5. {a: b, b: c, c: d}, returns {a: d}, removing intermediate b -> c, c -> d.
  6. Only keeps mappings where the key is NOT a value in another mapping (i.e., chain starting points).
  7. Args:
  8. d: Dictionary representing a mapping
  9. Returns:
  10. Dictionary with all transitive mappings collapsed, keeping only KV-pairs,
  11. such that K and V are starting and terminal point of a chain
  12. Examples:
  13. >>> collapse_transitive_map({"a": "b", "b": "c", "c": "d"})
  14. {'a': 'd'}
  15. >>> collapse_transitive_map({"a": "b", "x": "y"})
  16. {'a': 'b', 'x': 'y'}
  17. """
  18. if not d:
  19. return {}
  20. collapsed = {}
  21. values_set = set(d.values())
  22. for k in d:
  23. # Skip mappings that are in the value-set, meaning that they are
  24. # part of the mapping chain (for ex, {a -> b, b -> c})
  25. if k in values_set:
  26. continue
  27. cur = k
  28. visited = {cur}
  29. # Follow the chain until we reach a key that's not in the mapping
  30. while cur in d:
  31. next = d[cur]
  32. if next in visited:
  33. raise ValueError(f"Detected a cycle in the mapping {d}")
  34. visited.add(next)
  35. cur = next
  36. collapsed[k] = cur
  37. return collapsed