matchpy_connector.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. """
  2. The objects in this module allow the usage of the MatchPy pattern matching
  3. library on SymPy expressions.
  4. """
  5. import re
  6. from typing import List, Callable, NamedTuple, Any, Dict
  7. from sympy.core.sympify import _sympify
  8. from sympy.external import import_module
  9. from sympy.functions import (log, sin, cos, tan, cot, csc, sec, erf, gamma, uppergamma)
  10. from sympy.functions.elementary.hyperbolic import acosh, asinh, atanh, acoth, acsch, asech, cosh, sinh, tanh, coth, sech, csch
  11. from sympy.functions.elementary.trigonometric import atan, acsc, asin, acot, acos, asec
  12. from sympy.functions.special.error_functions import fresnelc, fresnels, erfc, erfi, Ei
  13. from sympy.core.add import Add
  14. from sympy.core.basic import Basic
  15. from sympy.core.expr import Expr
  16. from sympy.core.mul import Mul
  17. from sympy.core.power import Pow
  18. from sympy.core.relational import (Equality, Unequality)
  19. from sympy.core.symbol import Symbol
  20. from sympy.functions.elementary.exponential import exp
  21. from sympy.integrals.integrals import Integral
  22. from sympy.printing.repr import srepr
  23. from sympy.utilities.decorator import doctest_depends_on
  24. matchpy = import_module("matchpy")
  25. __doctest_requires__ = {('*',): ['matchpy']}
  26. if matchpy:
  27. from matchpy import Operation, CommutativeOperation, AssociativeOperation, OneIdentityOperation
  28. from matchpy.expressions.functions import op_iter, create_operation_expression, op_len
  29. Operation.register(Integral)
  30. Operation.register(Pow)
  31. OneIdentityOperation.register(Pow)
  32. Operation.register(Add)
  33. OneIdentityOperation.register(Add)
  34. CommutativeOperation.register(Add)
  35. AssociativeOperation.register(Add)
  36. Operation.register(Mul)
  37. OneIdentityOperation.register(Mul)
  38. CommutativeOperation.register(Mul)
  39. AssociativeOperation.register(Mul)
  40. Operation.register(Equality)
  41. CommutativeOperation.register(Equality)
  42. Operation.register(Unequality)
  43. CommutativeOperation.register(Unequality)
  44. Operation.register(exp)
  45. Operation.register(log)
  46. Operation.register(gamma)
  47. Operation.register(uppergamma)
  48. Operation.register(fresnels)
  49. Operation.register(fresnelc)
  50. Operation.register(erf)
  51. Operation.register(Ei)
  52. Operation.register(erfc)
  53. Operation.register(erfi)
  54. Operation.register(sin)
  55. Operation.register(cos)
  56. Operation.register(tan)
  57. Operation.register(cot)
  58. Operation.register(csc)
  59. Operation.register(sec)
  60. Operation.register(sinh)
  61. Operation.register(cosh)
  62. Operation.register(tanh)
  63. Operation.register(coth)
  64. Operation.register(csch)
  65. Operation.register(sech)
  66. Operation.register(asin)
  67. Operation.register(acos)
  68. Operation.register(atan)
  69. Operation.register(acot)
  70. Operation.register(acsc)
  71. Operation.register(asec)
  72. Operation.register(asinh)
  73. Operation.register(acosh)
  74. Operation.register(atanh)
  75. Operation.register(acoth)
  76. Operation.register(acsch)
  77. Operation.register(asech)
  78. @op_iter.register(Integral) # type: ignore
  79. def _(operation):
  80. return iter((operation._args[0],) + operation._args[1])
  81. @op_iter.register(Basic) # type: ignore
  82. def _(operation):
  83. return iter(operation._args)
  84. @op_len.register(Integral) # type: ignore
  85. def _(operation):
  86. return 1 + len(operation._args[1])
  87. @op_len.register(Basic) # type: ignore
  88. def _(operation):
  89. return len(operation._args)
  90. @create_operation_expression.register(Basic)
  91. def sympy_op_factory(old_operation, new_operands, variable_name=True):
  92. return type(old_operation)(*new_operands)
  93. if matchpy:
  94. from matchpy import Wildcard
  95. else:
  96. class Wildcard: # type: ignore
  97. def __init__(self, min_length, fixed_size, variable_name, optional):
  98. self.min_count = min_length
  99. self.fixed_size = fixed_size
  100. self.variable_name = variable_name
  101. self.optional = optional
  102. @doctest_depends_on(modules=('matchpy',))
  103. class _WildAbstract(Wildcard, Symbol):
  104. min_length: int # abstract field required in subclasses
  105. fixed_size: bool # abstract field required in subclasses
  106. def __init__(self, variable_name=None, optional=None, **assumptions):
  107. min_length = self.min_length
  108. fixed_size = self.fixed_size
  109. if optional is not None:
  110. optional = _sympify(optional)
  111. Wildcard.__init__(self, min_length, fixed_size, str(variable_name), optional)
  112. def __getstate__(self):
  113. return {
  114. "min_length": self.min_length,
  115. "fixed_size": self.fixed_size,
  116. "min_count": self.min_count,
  117. "variable_name": self.variable_name,
  118. "optional": self.optional,
  119. }
  120. def __new__(cls, variable_name=None, optional=None, **assumptions):
  121. cls._sanitize(assumptions, cls)
  122. return _WildAbstract.__xnew__(cls, variable_name, optional, **assumptions)
  123. def __getnewargs__(self):
  124. return self.variable_name, self.optional
  125. @staticmethod
  126. def __xnew__(cls, variable_name=None, optional=None, **assumptions):
  127. obj = Symbol.__xnew__(cls, variable_name, **assumptions)
  128. return obj
  129. def _hashable_content(self):
  130. if self.optional:
  131. return super()._hashable_content() + (self.min_count, self.fixed_size, self.variable_name, self.optional)
  132. else:
  133. return super()._hashable_content() + (self.min_count, self.fixed_size, self.variable_name)
  134. def __copy__(self) -> '_WildAbstract':
  135. return type(self)(variable_name=self.variable_name, optional=self.optional)
  136. def __repr__(self):
  137. return str(self)
  138. def __str__(self):
  139. return self.name
  140. @doctest_depends_on(modules=('matchpy',))
  141. class WildDot(_WildAbstract):
  142. min_length = 1
  143. fixed_size = True
  144. @doctest_depends_on(modules=('matchpy',))
  145. class WildPlus(_WildAbstract):
  146. min_length = 1
  147. fixed_size = False
  148. @doctest_depends_on(modules=('matchpy',))
  149. class WildStar(_WildAbstract):
  150. min_length = 0
  151. fixed_size = False
  152. def _get_srepr(expr):
  153. s = srepr(expr)
  154. s = re.sub(r"WildDot\('(\w+)'\)", r"\1", s)
  155. s = re.sub(r"WildPlus\('(\w+)'\)", r"*\1", s)
  156. s = re.sub(r"WildStar\('(\w+)'\)", r"*\1", s)
  157. return s
  158. class ReplacementInfo(NamedTuple):
  159. replacement: Any
  160. info: Any
  161. @doctest_depends_on(modules=('matchpy',))
  162. class Replacer:
  163. """
  164. Replacer object to perform multiple pattern matching and subexpression
  165. replacements in SymPy expressions.
  166. Examples
  167. ========
  168. Example to construct a simple first degree equation solver:
  169. >>> from sympy.utilities.matchpy_connector import WildDot, Replacer
  170. >>> from sympy import Equality, Symbol
  171. >>> x = Symbol("x")
  172. >>> a_ = WildDot("a_", optional=1)
  173. >>> b_ = WildDot("b_", optional=0)
  174. The lines above have defined two wildcards, ``a_`` and ``b_``, the
  175. coefficients of the equation `a x + b = 0`. The optional values specified
  176. indicate which expression to return in case no match is found, they are
  177. necessary in equations like `a x = 0` and `x + b = 0`.
  178. Create two constraints to make sure that ``a_`` and ``b_`` will not match
  179. any expression containing ``x``:
  180. >>> from matchpy import CustomConstraint
  181. >>> free_x_a = CustomConstraint(lambda a_: not a_.has(x))
  182. >>> free_x_b = CustomConstraint(lambda b_: not b_.has(x))
  183. Now create the rule replacer with the constraints:
  184. >>> replacer = Replacer(common_constraints=[free_x_a, free_x_b])
  185. Add the matching rule:
  186. >>> replacer.add(Equality(a_*x + b_, 0), -b_/a_)
  187. Let's try it:
  188. >>> replacer.replace(Equality(3*x + 4, 0))
  189. -4/3
  190. Notice that it will not match equations expressed with other patterns:
  191. >>> eq = Equality(3*x, 4)
  192. >>> replacer.replace(eq)
  193. Eq(3*x, 4)
  194. In order to extend the matching patterns, define another one (we also need
  195. to clear the cache, because the previous result has already been memorized
  196. and the pattern matcher will not iterate again if given the same expression)
  197. >>> replacer.add(Equality(a_*x, b_), b_/a_)
  198. >>> replacer._matcher.clear()
  199. >>> replacer.replace(eq)
  200. 4/3
  201. """
  202. def __init__(self, common_constraints: list = [], lambdify: bool = False, info: bool = False):
  203. self._matcher = matchpy.ManyToOneMatcher()
  204. self._common_constraint = common_constraints
  205. self._lambdify = lambdify
  206. self._info = info
  207. self._wildcards: Dict[str, Wildcard] = {}
  208. def _get_lambda(self, lambda_str: str) -> Callable[..., Expr]:
  209. exec("from sympy import *")
  210. return eval(lambda_str, locals())
  211. def _get_custom_constraint(self, constraint_expr: Expr, condition_template: str) -> Callable[..., Expr]:
  212. wilds = [x.name for x in constraint_expr.atoms(_WildAbstract)]
  213. lambdaargs = ', '.join(wilds)
  214. fullexpr = _get_srepr(constraint_expr)
  215. condition = condition_template.format(fullexpr)
  216. return matchpy.CustomConstraint(
  217. self._get_lambda(f"lambda {lambdaargs}: ({condition})"))
  218. def _get_custom_constraint_nonfalse(self, constraint_expr: Expr) -> Callable[..., Expr]:
  219. return self._get_custom_constraint(constraint_expr, "({}) != False")
  220. def _get_custom_constraint_true(self, constraint_expr: Expr) -> Callable[..., Expr]:
  221. return self._get_custom_constraint(constraint_expr, "({}) == True")
  222. def add(self, expr: Expr, replacement, conditions_true: List[Expr] = [],
  223. conditions_nonfalse: List[Expr] = [], info: Any = None) -> None:
  224. expr = _sympify(expr)
  225. replacement = _sympify(replacement)
  226. constraints = self._common_constraint[:]
  227. constraint_conditions_true = [
  228. self._get_custom_constraint_true(cond) for cond in conditions_true]
  229. constraint_conditions_nonfalse = [
  230. self._get_custom_constraint_nonfalse(cond) for cond in conditions_nonfalse]
  231. constraints.extend(constraint_conditions_true)
  232. constraints.extend(constraint_conditions_nonfalse)
  233. pattern = matchpy.Pattern(expr, *constraints)
  234. if self._lambdify:
  235. lambda_str = f"lambda {', '.join((x.name for x in expr.atoms(_WildAbstract)))}: {_get_srepr(replacement)}"
  236. lambda_expr = self._get_lambda(lambda_str)
  237. replacement = lambda_expr
  238. else:
  239. self._wildcards.update({str(i): i for i in expr.atoms(Wildcard)})
  240. if self._info:
  241. replacement = ReplacementInfo(replacement, info)
  242. self._matcher.add(pattern, replacement)
  243. def replace(self, expression, max_count: int = -1):
  244. # This method partly rewrites the .replace method of ManyToOneReplacer
  245. # in MatchPy.
  246. # License: https://github.com/HPAC/matchpy/blob/master/LICENSE
  247. infos = []
  248. replaced = True
  249. replace_count = 0
  250. while replaced and (max_count < 0 or replace_count < max_count):
  251. replaced = False
  252. for subexpr, pos in matchpy.preorder_iter_with_position(expression):
  253. try:
  254. replacement_data, subst = next(iter(self._matcher.match(subexpr)))
  255. if self._info:
  256. replacement = replacement_data.replacement
  257. infos.append(replacement_data.info)
  258. else:
  259. replacement = replacement_data
  260. if self._lambdify:
  261. result = replacement(**subst)
  262. else:
  263. result = replacement.xreplace({self._wildcards[k]: v for k, v in subst.items()})
  264. expression = matchpy.functions.replace(expression, pos, result)
  265. replaced = True
  266. break
  267. except StopIteration:
  268. pass
  269. replace_count += 1
  270. if self._info:
  271. return expression, infos
  272. else:
  273. return expression