epathtools.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. """Tools for manipulation of expressions using paths. """
  2. from sympy.core import Basic
  3. class EPath:
  4. r"""
  5. Manipulate expressions using paths.
  6. EPath grammar in EBNF notation::
  7. literal ::= /[A-Za-z_][A-Za-z_0-9]*/
  8. number ::= /-?\d+/
  9. type ::= literal
  10. attribute ::= literal "?"
  11. all ::= "*"
  12. slice ::= "[" number? (":" number? (":" number?)?)? "]"
  13. range ::= all | slice
  14. query ::= (type | attribute) ("|" (type | attribute))*
  15. selector ::= range | query range?
  16. path ::= "/" selector ("/" selector)*
  17. See the docstring of the epath() function.
  18. """
  19. __slots__ = ("_path", "_epath")
  20. def __new__(cls, path):
  21. """Construct new EPath. """
  22. if isinstance(path, EPath):
  23. return path
  24. if not path:
  25. raise ValueError("empty EPath")
  26. _path = path
  27. if path[0] == '/':
  28. path = path[1:]
  29. else:
  30. raise NotImplementedError("non-root EPath")
  31. epath = []
  32. for selector in path.split('/'):
  33. selector = selector.strip()
  34. if not selector:
  35. raise ValueError("empty selector")
  36. index = 0
  37. for c in selector:
  38. if c.isalnum() or c in ('_', '|', '?'):
  39. index += 1
  40. else:
  41. break
  42. attrs = []
  43. types = []
  44. if index:
  45. elements = selector[:index]
  46. selector = selector[index:]
  47. for element in elements.split('|'):
  48. element = element.strip()
  49. if not element:
  50. raise ValueError("empty element")
  51. if element.endswith('?'):
  52. attrs.append(element[:-1])
  53. else:
  54. types.append(element)
  55. span = None
  56. if selector == '*':
  57. pass
  58. else:
  59. if selector.startswith('['):
  60. try:
  61. i = selector.index(']')
  62. except ValueError:
  63. raise ValueError("expected ']', got EOL")
  64. _span, span = selector[1:i], []
  65. if ':' not in _span:
  66. span = int(_span)
  67. else:
  68. for elt in _span.split(':', 3):
  69. if not elt:
  70. span.append(None)
  71. else:
  72. span.append(int(elt))
  73. span = slice(*span)
  74. selector = selector[i + 1:]
  75. if selector:
  76. raise ValueError("trailing characters in selector")
  77. epath.append((attrs, types, span))
  78. obj = object.__new__(cls)
  79. obj._path = _path
  80. obj._epath = epath
  81. return obj
  82. def __repr__(self):
  83. return "%s(%r)" % (self.__class__.__name__, self._path)
  84. def _get_ordered_args(self, expr):
  85. """Sort ``expr.args`` using printing order. """
  86. if expr.is_Add:
  87. return expr.as_ordered_terms()
  88. elif expr.is_Mul:
  89. return expr.as_ordered_factors()
  90. else:
  91. return expr.args
  92. def _hasattrs(self, expr, attrs) -> bool:
  93. """Check if ``expr`` has any of ``attrs``. """
  94. return all(hasattr(expr, attr) for attr in attrs)
  95. def _hastypes(self, expr, types):
  96. """Check if ``expr`` is any of ``types``. """
  97. _types = [ cls.__name__ for cls in expr.__class__.mro() ]
  98. return bool(set(_types).intersection(types))
  99. def _has(self, expr, attrs, types):
  100. """Apply ``_hasattrs`` and ``_hastypes`` to ``expr``. """
  101. if not (attrs or types):
  102. return True
  103. if attrs and self._hasattrs(expr, attrs):
  104. return True
  105. if types and self._hastypes(expr, types):
  106. return True
  107. return False
  108. def apply(self, expr, func, args=None, kwargs=None):
  109. """
  110. Modify parts of an expression selected by a path.
  111. Examples
  112. ========
  113. >>> from sympy.simplify.epathtools import EPath
  114. >>> from sympy import sin, cos, E
  115. >>> from sympy.abc import x, y, z, t
  116. >>> path = EPath("/*/[0]/Symbol")
  117. >>> expr = [((x, 1), 2), ((3, y), z)]
  118. >>> path.apply(expr, lambda expr: expr**2)
  119. [((x**2, 1), 2), ((3, y**2), z)]
  120. >>> path = EPath("/*/*/Symbol")
  121. >>> expr = t + sin(x + 1) + cos(x + y + E)
  122. >>> path.apply(expr, lambda expr: 2*expr)
  123. t + sin(2*x + 1) + cos(2*x + 2*y + E)
  124. """
  125. def _apply(path, expr, func):
  126. if not path:
  127. return func(expr)
  128. else:
  129. selector, path = path[0], path[1:]
  130. attrs, types, span = selector
  131. if isinstance(expr, Basic):
  132. if not expr.is_Atom:
  133. args, basic = self._get_ordered_args(expr), True
  134. else:
  135. return expr
  136. elif hasattr(expr, '__iter__'):
  137. args, basic = expr, False
  138. else:
  139. return expr
  140. args = list(args)
  141. if span is not None:
  142. if isinstance(span, slice):
  143. indices = range(*span.indices(len(args)))
  144. else:
  145. indices = [span]
  146. else:
  147. indices = range(len(args))
  148. for i in indices:
  149. try:
  150. arg = args[i]
  151. except IndexError:
  152. continue
  153. if self._has(arg, attrs, types):
  154. args[i] = _apply(path, arg, func)
  155. if basic:
  156. return expr.func(*args)
  157. else:
  158. return expr.__class__(args)
  159. _args, _kwargs = args or (), kwargs or {}
  160. _func = lambda expr: func(expr, *_args, **_kwargs)
  161. return _apply(self._epath, expr, _func)
  162. def select(self, expr):
  163. """
  164. Retrieve parts of an expression selected by a path.
  165. Examples
  166. ========
  167. >>> from sympy.simplify.epathtools import EPath
  168. >>> from sympy import sin, cos, E
  169. >>> from sympy.abc import x, y, z, t
  170. >>> path = EPath("/*/[0]/Symbol")
  171. >>> expr = [((x, 1), 2), ((3, y), z)]
  172. >>> path.select(expr)
  173. [x, y]
  174. >>> path = EPath("/*/*/Symbol")
  175. >>> expr = t + sin(x + 1) + cos(x + y + E)
  176. >>> path.select(expr)
  177. [x, x, y]
  178. """
  179. result = []
  180. def _select(path, expr):
  181. if not path:
  182. result.append(expr)
  183. else:
  184. selector, path = path[0], path[1:]
  185. attrs, types, span = selector
  186. if isinstance(expr, Basic):
  187. args = self._get_ordered_args(expr)
  188. elif hasattr(expr, '__iter__'):
  189. args = expr
  190. else:
  191. return
  192. if span is not None:
  193. if isinstance(span, slice):
  194. args = args[span]
  195. else:
  196. try:
  197. args = [args[span]]
  198. except IndexError:
  199. return
  200. for arg in args:
  201. if self._has(arg, attrs, types):
  202. _select(path, arg)
  203. _select(self._epath, expr)
  204. return result
  205. def epath(path, expr=None, func=None, args=None, kwargs=None):
  206. r"""
  207. Manipulate parts of an expression selected by a path.
  208. Explanation
  209. ===========
  210. This function allows to manipulate large nested expressions in single
  211. line of code, utilizing techniques to those applied in XML processing
  212. standards (e.g. XPath).
  213. If ``func`` is ``None``, :func:`epath` retrieves elements selected by
  214. the ``path``. Otherwise it applies ``func`` to each matching element.
  215. Note that it is more efficient to create an EPath object and use the select
  216. and apply methods of that object, since this will compile the path string
  217. only once. This function should only be used as a convenient shortcut for
  218. interactive use.
  219. This is the supported syntax:
  220. * select all: ``/*``
  221. Equivalent of ``for arg in args:``.
  222. * select slice: ``/[0]`` or ``/[1:5]`` or ``/[1:5:2]``
  223. Supports standard Python's slice syntax.
  224. * select by type: ``/list`` or ``/list|tuple``
  225. Emulates ``isinstance()``.
  226. * select by attribute: ``/__iter__?``
  227. Emulates ``hasattr()``.
  228. Parameters
  229. ==========
  230. path : str | EPath
  231. A path as a string or a compiled EPath.
  232. expr : Basic | iterable
  233. An expression or a container of expressions.
  234. func : callable (optional)
  235. A callable that will be applied to matching parts.
  236. args : tuple (optional)
  237. Additional positional arguments to ``func``.
  238. kwargs : dict (optional)
  239. Additional keyword arguments to ``func``.
  240. Examples
  241. ========
  242. >>> from sympy.simplify.epathtools import epath
  243. >>> from sympy import sin, cos, E
  244. >>> from sympy.abc import x, y, z, t
  245. >>> path = "/*/[0]/Symbol"
  246. >>> expr = [((x, 1), 2), ((3, y), z)]
  247. >>> epath(path, expr)
  248. [x, y]
  249. >>> epath(path, expr, lambda expr: expr**2)
  250. [((x**2, 1), 2), ((3, y**2), z)]
  251. >>> path = "/*/*/Symbol"
  252. >>> expr = t + sin(x + 1) + cos(x + y + E)
  253. >>> epath(path, expr)
  254. [x, x, y]
  255. >>> epath(path, expr, lambda expr: 2*expr)
  256. t + sin(2*x + 1) + cos(2*x + 2*y + E)
  257. """
  258. _epath = EPath(path)
  259. if expr is None:
  260. return _epath
  261. if func is None:
  262. return _epath.select(expr)
  263. else:
  264. return _epath.apply(expr, func, args, kwargs)