| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234 |
- """ Generic Unification algorithm for expression trees with lists of children
- This implementation is a direct translation of
- Artificial Intelligence: A Modern Approach by Stuart Russell and Peter Norvig
- Second edition, section 9.2, page 276
- It is modified in the following ways:
- 1. We allow associative and commutative Compound expressions. This results in
- combinatorial blowup.
- 2. We explore the tree lazily.
- 3. We provide generic interfaces to symbolic algebra libraries in Python.
- A more traditional version can be found here
- http://aima.cs.berkeley.edu/python/logic.html
- """
- from sympy.utilities.iterables import kbins
class Compound:
    """ A little class to represent an interior node in the tree

    This is analogous to SymPy.Basic for non-Atoms

    op   - the operator at this node.
    args - the children of this node. They must be hashable (e.g. a tuple)
           for the Compound itself to be hashable, because ``__hash__``
           hashes ``args`` directly.

    Two Compounds are equal when they have the same type, equal ``op`` and
    equal ``args``.
    """
    def __init__(self, op, args):
        self.op = op
        self.args = args

    def __eq__(self, other):
        return (type(self) is type(other) and self.op == other.op and
                self.args == other.args)

    def __hash__(self):
        return hash((type(self), self.op, self.args))

    def __str__(self):
        return "%s[%s]" % (str(self.op), ', '.join(map(str, self.args)))

    # Reuse the compact form for repr() so Compounds nested inside lists,
    # tuples or substitution dicts print readably instead of as
    # ``<Compound object at 0x...>``.
    __repr__ = __str__
class Variable:
    """ A Wild token

    arg - an identifier for the variable. Two Variables are equal (and hash
          alike) exactly when their ``arg`` values are equal.
    """
    def __init__(self, arg):
        self.arg = arg

    def __eq__(self, other):
        return type(self) is type(other) and self.arg == other.arg

    def __hash__(self):
        return hash((type(self), self.arg))

    def __str__(self):
        return "Variable(%s)" % str(self.arg)

    # Without this, substitutions such as {Variable(a): 'y'} would display
    # their keys as ``<Variable object at 0x...>``.
    __repr__ = __str__
class CondVariable:
    """ A wild token that matches conditionally.

    arg - a wild token.
    valid - an additional constraining function on a match. It is called
            with the candidate subtree, and the binding is made only when
            it returns a truthy value.

    Two CondVariables are equal only when both ``arg`` and ``valid``
    compare equal, so the same ``arg`` with a different ``valid`` function
    is a different token.
    """
    def __init__(self, arg, valid):
        self.arg = arg
        self.valid = valid

    def __eq__(self, other):
        return (type(self) is type(other) and
                self.arg == other.arg and
                self.valid == other.valid)

    def __hash__(self):
        return hash((type(self), self.arg, self.valid))

    def __str__(self):
        return "CondVariable(%s)" % str(self.arg)

    # Keep repr() readable when CondVariables appear as substitution keys
    # or inside other containers.
    __repr__ = __str__
def unify(x, y, s=None, **fns):
    """ Unify two expressions.

    Parameters
    ==========

    x, y - expression trees containing leaves, Compounds, Variables and
           CondVariables. Tuples, lists and sets (see ``is_args``) of equal
           length are unified element by element.
    s - a mapping of variables to subtrees. Any falsy value (``None`` or an
        empty mapping) is replaced by a fresh empty dict.
    fns - optional keyword callables ``is_commutative`` and
          ``is_associative``, each called with a Compound. Both default to
          a function that always returns False. They are also forwarded
          unchanged to every recursive call.

    Returns
    =======

    lazy sequence of mappings {Variable: subtree}. Nothing is yielded when
    x and y cannot be unified.

    Examples
    ========

    >>> from sympy.unify.core import unify, Compound, Variable
    >>> expr = Compound("Add", ("x", "y"))
    >>> pattern = Compound("Add", ("x", Variable("a")))
    >>> next(unify(expr, pattern, {}))
    {Variable(a): 'y'}
    """
    s = s or {}

    if x == y:
        # Equal terms unify without adding any bindings.
        yield s
    elif isinstance(x, (Variable, CondVariable)):
        # A wild token on either side is bound or checked by unify_var.
        yield from unify_var(x, y, s, **fns)
    elif isinstance(y, (Variable, CondVariable)):
        yield from unify_var(y, x, s, **fns)
    elif isinstance(x, Compound) and isinstance(y, Compound):
        is_commutative = fns.get('is_commutative', lambda x: False)
        is_associative = fns.get('is_associative', lambda x: False)
        # The operators must unify first. Each operator substitution sop is
        # then extended by unifying the arguments.
        for sop in unify(x.op, y.op, s, **fns):
            if is_associative(x) and is_associative(y):
                # a is the Compound with fewer args (y wins a tie).
                # allcombinations regroups the longer argument list into as
                # many pieces as the shorter one has elements.
                a, b = (x, y) if len(x.args) < len(y.args) else (y, x)
                if is_commutative(x) and is_commutative(y):
                    # Both sides commutative: also try reordered groupings.
                    combs = allcombinations(a.args, b.args, 'commutative')
                else:
                    combs = allcombinations(a.args, b.args, 'associative')
                for aaargs, bbargs in combs:
                    # Re-wrap each group with its side's operator. unpack
                    # turns a one-element group back into the bare element.
                    aa = [unpack(Compound(a.op, arg)) for arg in aaargs]
                    bb = [unpack(Compound(b.op, arg)) for arg in bbargs]
                    yield from unify(aa, bb, sop, **fns)
            elif len(x.args) == len(y.args):
                # Not associative: arguments are unified positionally, and a
                # different number of args means no match.
                yield from unify(x.args, y.args, sop, **fns)

    elif is_args(x) and is_args(y) and len(x) == len(y):
        if len(x) == 0:
            yield s
        else:
            # Unify the heads, then thread each head substitution through
            # the unification of the tails.
            # NOTE(review): is_args also accepts sets, which cannot be
            # indexed with x[0] / x[1:]. Two equal-length non-empty sets
            # would raise TypeError here, so confirm callers never pass
            # sets.
            for shead in unify(x[0], y[0], s, **fns):
                yield from unify(x[1:], y[1:], shead, **fns)
def unify_var(var, x, s, **fns):
    """ Lazily extend substitution ``s`` by binding the wild token ``var`` to ``x``.

    * If ``var`` is already bound in ``s``, its current value is unified
      with ``x`` instead.
    * Nothing is yielded when ``var`` occurs inside ``x`` (occur check).
    * Otherwise a CondVariable is bound only if ``var.valid(x)`` is truthy,
      and a plain Variable is always bound.

    New bindings are produced with ``assoc``, which copies ``s`` rather than
    modifying it.
    """
    if var in s:
        yield from unify(s[var], x, s, **fns)
        return
    if occur_check(var, x):
        # Binding var to a term that contains var would be circular.
        return
    accepted_conditionally = isinstance(var, CondVariable) and var.valid(x)
    if accepted_conditionally or isinstance(var, Variable):
        yield assoc(s, var, x)
def occur_check(var, x):
    """ var occurs in subtree owned by x?

    Returns True if ``var`` equals ``x`` or any node reachable from ``x``.
    Reachable nodes are a Compound's operator and arguments, and the items
    of a tuple/list/set (see ``is_args``). Returns False otherwise.
    """
    if var == x:
        return True
    elif isinstance(x, Compound):
        # The operator is part of the term too. unify() unifies ``x.op``
        # just like the arguments, so a Variable can sit in operator
        # position. Checking only ``x.args`` would let unify_var bind a
        # variable to a term whose operator is that same variable, which
        # creates a circular substitution.
        return occur_check(var, x.op) or occur_check(var, x.args)
    elif is_args(x):
        return any(occur_check(var, xi) for xi in x)
    return False
def assoc(d, key, val):
    """ Return copy of d with key associated to val

    ``d`` is duplicated through its own ``copy()`` method and left
    unmodified. Any existing entry for ``key`` in the copy is overwritten.
    """
    extended = d.copy()
    extended[key] = val
    return extended
def is_args(x):
    """ Is x a traditional iterable?

    True only when ``type(x)`` is exactly ``tuple``, ``list`` or ``set``.
    Subclasses of those, and other iterables such as strings, dicts or
    frozensets, are not treated as argument sequences.
    """
    kind = type(x)
    return kind is tuple or kind is list or kind is set
def unpack(x):
    """ Collapse a Compound with exactly one argument to that argument.

    Any other value, including Compounds with zero or several arguments,
    is returned unchanged.
    """
    if not isinstance(x, Compound) or len(x.args) != 1:
        return x
    return x.args[0]
def allcombinations(A, B, ordered):
    """
    Restructure A and B to have the same number of elements.

    Parameters
    ==========

    ordered must be either 'commutative' or 'associative'.

    A and B can be rearranged so that the larger of the two lists is
    reorganized into smaller sublists.

    Each yielded item is a pair ``(A', B')``. The shorter sequence is
    returned as a tuple of 1-tuples. The longer one is split by ``kbins``
    (on its indices) into as many pieces as the shorter one has elements.
    'associative' keeps the elements in their original order, while
    'commutative' also yields every reordering, as the examples show.
    When ``len(A) == len(B)``, B is the side that gets split only if
    ``A == B``; otherwise A is.

    Examples
    ========

    >>> from sympy.unify.core import allcombinations
    >>> for x in allcombinations((1, 2, 3), (5, 6), 'associative'): print(x)
    (((1,), (2, 3)), ((5,), (6,)))
    (((1, 2), (3,)), ((5,), (6,)))

    >>> for x in allcombinations((1, 2, 3), (5, 6), 'commutative'): print(x)
    (((1,), (2, 3)), ((5,), (6,)))
    (((1, 2), (3,)), ((5,), (6,)))
    (((1,), (3, 2)), ((5,), (6,)))
    (((1, 3), (2,)), ((5,), (6,)))
    (((2,), (1, 3)), ((5,), (6,)))
    (((2, 1), (3,)), ((5,), (6,)))
    (((2,), (3, 1)), ((5,), (6,)))
    (((2, 3), (1,)), ((5,), (6,)))
    (((3,), (1, 2)), ((5,), (6,)))
    (((3, 1), (2,)), ((5,), (6,)))
    (((3,), (2, 1)), ((5,), (6,)))
    (((3, 2), (1,)), ((5,), (6,)))
    """
    # Translate the public mode names into kbins' ``ordered`` codes.
    if ordered == "commutative":
        ordered = 11
    if ordered == "associative":
        ordered = None
    sm, bg = (A, B) if len(A) < len(B) else (B, A)
    # ``bg == B`` is loop-invariant, and comparing the sequences is
    # element-wise (recursing through Compound.__eq__ for nested terms).
    # Evaluate it once rather than once per partition in this already
    # combinatorial loop.
    split_b = bg == B
    for part in kbins(list(range(len(bg))), len(sm), ordered=ordered):
        if split_b:
            yield tuple((a,) for a in A), partition(B, part)
        else:
            yield partition(A, part), tuple((b,) for b in B)
def partition(it, part):
    """ Partition a tuple/list into pieces defined by indices.

    ``part`` is a sequence of index groups. Each group becomes one piece
    holding ``it[i]`` for every index ``i`` in that group, in the listed
    order. Both the pieces and the outer result are built with
    ``type(it)``, so a tuple gives a tuple of tuples and a list gives a
    list of lists.

    Examples
    ========

    >>> from sympy.unify.core import partition
    >>> partition((10, 20, 30, 40), [[0, 1, 2], [3]])
    ((10, 20, 30), (40,))
    """
    make = type(it)
    pieces = []
    for group in part:
        pieces.append(make([it[i] for i in group]))
    return make(pieces)
def index(it, ind):
    """ Fancy indexing into an indexable iterable (tuple, list).

    Returns a new container of the same type as ``it`` holding ``it[i]``
    for each ``i`` in ``ind``, in that order. Indices may repeat or be
    negative, exactly as with ordinary subscripting.

    Examples
    ========

    >>> from sympy.unify.core import index
    >>> index([10, 20, 30], (1, 2, 0))
    [20, 30, 10]
    """
    picked = [it[position] for position in ind]
    return type(it)(picked)
|