match.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # mypy: allow-untyped-defs
  2. from .core import reify, unify # type: ignore[attr-defined]
  3. from .unification_tools import first, groupby # type: ignore[import]
  4. from .utils import _toposort, freeze
  5. from .variable import isvar
  6. class Dispatcher:
  7. def __init__(self, name):
  8. self.name = name
  9. self.funcs = {}
  10. self.ordering = []
  11. def add(self, signature, func):
  12. self.funcs[freeze(signature)] = func
  13. self.ordering = ordering(self.funcs)
  14. def __call__(self, *args, **kwargs):
  15. func, _ = self.resolve(args)
  16. return func(*args, **kwargs)
  17. def resolve(self, args):
  18. n = len(args)
  19. for signature in self.ordering:
  20. if len(signature) != n:
  21. continue
  22. s = unify(freeze(args), signature)
  23. if s is not False:
  24. result = self.funcs[signature]
  25. return result, s
  26. raise NotImplementedError(
  27. "No match found. \nKnown matches: "
  28. + str(self.ordering)
  29. + "\nInput: "
  30. + str(args)
  31. )
  32. def register(self, *signature):
  33. def _(func):
  34. self.add(signature, func)
  35. return self
  36. return _
  37. class VarDispatcher(Dispatcher):
  38. """A dispatcher that calls functions with variable names
  39. >>> # xdoctest: +SKIP
  40. >>> d = VarDispatcher("d")
  41. >>> x = var("x")
  42. >>> @d.register("inc", x)
  43. ... def f(x):
  44. ... return x + 1
  45. >>> @d.register("double", x)
  46. ... def f(x):
  47. ... return x * 2
  48. >>> d("inc", 10)
  49. 11
  50. >>> d("double", 10)
  51. 20
  52. """
  53. def __call__(self, *args, **kwargs):
  54. func, s = self.resolve(args)
  55. d = {k.token: v for k, v in s.items()}
  56. return func(**d)
  57. global_namespace = {} # type: ignore[var-annotated]
  58. def match(*signature, **kwargs):
  59. namespace = kwargs.get("namespace", global_namespace)
  60. dispatcher = kwargs.get("Dispatcher", Dispatcher)
  61. def _(func):
  62. name = func.__name__
  63. if name not in namespace:
  64. namespace[name] = dispatcher(name)
  65. d = namespace[name]
  66. d.add(signature, func)
  67. return d
  68. return _
  69. def supercedes(a, b):
  70. """``a`` is a more specific match than ``b``"""
  71. if isvar(b) and not isvar(a):
  72. return True
  73. s = unify(a, b)
  74. if s is False:
  75. return False
  76. s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)}
  77. if reify(a, s) == a:
  78. return True
  79. if reify(b, s) == b:
  80. return False
  81. # Taken from multipledispatch
  82. def edge(a, b, tie_breaker=hash):
  83. """A should be checked before B
  84. Tie broken by tie_breaker, defaults to ``hash``
  85. """
  86. if supercedes(a, b):
  87. if supercedes(b, a):
  88. return tie_breaker(a) > tie_breaker(b)
  89. else:
  90. return True
  91. return False
  92. # Taken from multipledispatch
  93. def ordering(signatures):
  94. """A sane ordering of signatures to check, first to last
  95. Topological sort of edges as given by ``edge`` and ``supercedes``
  96. """
  97. signatures = list(map(tuple, signatures))
  98. edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
  99. edges = groupby(first, edges)
  100. for s in signatures:
  101. if s not in edges:
  102. edges[s] = []
  103. edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[attr-defined, assignment]
  104. return _toposort(edges)