utils.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # mypy: allow-untyped-defs
  2. __all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"]
  3. def hashable(x):
  4. try:
  5. hash(x)
  6. return True
  7. except TypeError:
  8. return False
  9. def transitive_get(key, d):
  10. """Transitive dict.get
  11. >>> d = {1: 2, 2: 3, 3: 4}
  12. >>> d.get(1)
  13. 2
  14. >>> transitive_get(1, d)
  15. 4
  16. """
  17. while hashable(key) and key in d:
  18. key = d[key]
  19. return key
  20. def raises(err, lamda): # codespell:ignore lamda
  21. try:
  22. lamda() # codespell:ignore lamda
  23. return False
  24. except err:
  25. return True
  26. # Taken from theano/theano/gof/sched.py
  27. # Avoids licensing issues because this was written by Matthew Rocklin
  28. def _toposort(edges):
  29. """Topological sort algorithm by Kahn [1] - O(nodes + vertices)
  30. inputs:
  31. edges - a dict of the form {a: {b, c}} where b and c depend on a
  32. outputs:
  33. L - an ordered list of nodes that satisfy the dependencies of edges
  34. >>> # xdoctest: +SKIP
  35. >>> _toposort({1: (2, 3), 2: (3,)})
  36. [1, 2, 3]
  37. Closely follows the wikipedia page [2]
  38. [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
  39. Communications of the ACM
  40. [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
  41. """
  42. incoming_edges = reverse_dict(edges)
  43. incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
  44. S = {v for v in edges if v not in incoming_edges}
  45. L = []
  46. while S:
  47. n = S.pop()
  48. L.append(n)
  49. for m in edges.get(n, ()):
  50. if n not in incoming_edges[m]:
  51. raise AssertionError(f"Expected {n} in incoming_edges[{m}]")
  52. incoming_edges[m].remove(n)
  53. if not incoming_edges[m]:
  54. S.add(m)
  55. if any(incoming_edges.get(v) for v in edges):
  56. raise ValueError("Input has cycles")
  57. return L
  58. def reverse_dict(d):
  59. """Reverses direction of dependence dict
  60. >>> d = {"a": (1, 2), "b": (2, 3), "c": ()}
  61. >>> reverse_dict(d) # doctest: +SKIP
  62. {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
  63. :note: dict order are not deterministic. As we iterate on the
  64. input dict, it make the output of this function depend on the
  65. dict order. So this function output order should be considered
  66. as undeterministic.
  67. """
  68. result = {} # type: ignore[var-annotated]
  69. for key in d:
  70. for val in d[key]:
  71. # pyrefly: ignore [unsupported-operation]
  72. result[val] = result.get(val, ()) + (key,)
  73. return result
  74. def xfail(func):
  75. try:
  76. func()
  77. raise Exception("XFailed test passed") # pragma:nocover # noqa: TRY002
  78. except Exception:
  79. pass
  80. def freeze(d):
  81. """Freeze container to hashable form
  82. >>> freeze(1)
  83. 1
  84. >>> freeze([1, 2])
  85. (1, 2)
  86. >>> freeze({1: 2}) # doctest: +SKIP
  87. frozenset([(1, 2)])
  88. """
  89. if isinstance(d, dict):
  90. return frozenset(map(freeze, d.items()))
  91. if isinstance(d, set):
  92. return frozenset(map(freeze, d))
  93. if isinstance(d, (tuple, list)):
  94. return tuple(map(freeze, d))
  95. return d