traversal.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. from __future__ import annotations
  2. from typing import Iterator
  3. from .basic import Basic
  4. from .sorting import ordered
  5. from .sympify import sympify
  6. from sympy.utilities.iterables import iterable
  7. def iterargs(expr):
  8. """Yield the args of a Basic object in a breadth-first traversal.
  9. Depth-traversal stops if `arg.args` is either empty or is not
  10. an iterable.
  11. Examples
  12. ========
  13. >>> from sympy import Integral, Function
  14. >>> from sympy.abc import x
  15. >>> f = Function('f')
  16. >>> from sympy.core.traversal import iterargs
  17. >>> list(iterargs(Integral(f(x), (f(x), 1))))
  18. [Integral(f(x), (f(x), 1)), f(x), (f(x), 1), x, f(x), 1, x]
  19. See Also
  20. ========
  21. iterfreeargs, preorder_traversal
  22. """
  23. args = [expr]
  24. for i in args:
  25. yield i
  26. args.extend(i.args)
  27. def iterfreeargs(expr, _first=True):
  28. """Yield the args of a Basic object in a breadth-first traversal.
  29. Depth-traversal stops if `arg.args` is either empty or is not
  30. an iterable. The bound objects of an expression will be returned
  31. as canonical variables.
  32. Examples
  33. ========
  34. >>> from sympy import Integral, Function
  35. >>> from sympy.abc import x
  36. >>> f = Function('f')
  37. >>> from sympy.core.traversal import iterfreeargs
  38. >>> list(iterfreeargs(Integral(f(x), (f(x), 1))))
  39. [Integral(f(x), (f(x), 1)), 1]
  40. See Also
  41. ========
  42. iterargs, preorder_traversal
  43. """
  44. args = [expr]
  45. for i in args:
  46. yield i
  47. if _first and hasattr(i, 'bound_symbols'):
  48. void = i.canonical_variables.values()
  49. for i in iterfreeargs(i.as_dummy(), _first=False):
  50. if not i.has(*void):
  51. yield i
  52. args.extend(i.args)
  53. class preorder_traversal:
  54. """
  55. Do a pre-order traversal of a tree.
  56. This iterator recursively yields nodes that it has visited in a pre-order
  57. fashion. That is, it yields the current node then descends through the
  58. tree breadth-first to yield all of a node's children's pre-order
  59. traversal.
  60. For an expression, the order of the traversal depends on the order of
  61. .args, which in many cases can be arbitrary.
  62. Parameters
  63. ==========
  64. node : SymPy expression
  65. The expression to traverse.
  66. keys : (default None) sort key(s)
  67. The key(s) used to sort args of Basic objects. When None, args of Basic
  68. objects are processed in arbitrary order. If key is defined, it will
  69. be passed along to ordered() as the only key(s) to use to sort the
  70. arguments; if ``key`` is simply True then the default keys of ordered
  71. will be used.
  72. Yields
  73. ======
  74. subtree : SymPy expression
  75. All of the subtrees in the tree.
  76. Examples
  77. ========
  78. >>> from sympy import preorder_traversal, symbols
  79. >>> x, y, z = symbols('x y z')
  80. The nodes are returned in the order that they are encountered unless key
  81. is given; simply passing key=True will guarantee that the traversal is
  82. unique.
  83. >>> list(preorder_traversal((x + y)*z, keys=None)) # doctest: +SKIP
  84. [z*(x + y), z, x + y, y, x]
  85. >>> list(preorder_traversal((x + y)*z, keys=True))
  86. [z*(x + y), z, x + y, x, y]
  87. """
  88. def __init__(self, node, keys=None):
  89. self._skip_flag = False
  90. self._pt = self._preorder_traversal(node, keys)
  91. def _preorder_traversal(self, node, keys):
  92. yield node
  93. if self._skip_flag:
  94. self._skip_flag = False
  95. return
  96. if isinstance(node, Basic):
  97. if not keys and hasattr(node, '_argset'):
  98. # LatticeOp keeps args as a set. We should use this if we
  99. # don't care about the order, to prevent unnecessary sorting.
  100. args = node._argset
  101. else:
  102. args = node.args
  103. if keys:
  104. if keys != True:
  105. args = ordered(args, keys, default=False)
  106. else:
  107. args = ordered(args)
  108. for arg in args:
  109. yield from self._preorder_traversal(arg, keys)
  110. elif iterable(node):
  111. for item in node:
  112. yield from self._preorder_traversal(item, keys)
  113. def skip(self):
  114. """
  115. Skip yielding current node's (last yielded node's) subtrees.
  116. Examples
  117. ========
  118. >>> from sympy import preorder_traversal, symbols
  119. >>> x, y, z = symbols('x y z')
  120. >>> pt = preorder_traversal((x + y*z)*z)
  121. >>> for i in pt:
  122. ... print(i)
  123. ... if i == x + y*z:
  124. ... pt.skip()
  125. z*(x + y*z)
  126. z
  127. x + y*z
  128. """
  129. self._skip_flag = True
  130. def __next__(self):
  131. return next(self._pt)
  132. def __iter__(self) -> Iterator[Basic]:
  133. return self
  134. def use(expr, func, level=0, args=(), kwargs={}):
  135. """
  136. Use ``func`` to transform ``expr`` at the given level.
  137. Examples
  138. ========
  139. >>> from sympy import use, expand
  140. >>> from sympy.abc import x, y
  141. >>> f = (x + y)**2*x + 1
  142. >>> use(f, expand, level=2)
  143. x*(x**2 + 2*x*y + y**2) + 1
  144. >>> expand(f)
  145. x**3 + 2*x**2*y + x*y**2 + 1
  146. """
  147. def _use(expr, level):
  148. if not level:
  149. return func(expr, *args, **kwargs)
  150. else:
  151. if expr.is_Atom:
  152. return expr
  153. else:
  154. level -= 1
  155. _args = [_use(arg, level) for arg in expr.args]
  156. return expr.__class__(*_args)
  157. return _use(sympify(expr), level)
  158. def walk(e, *target):
  159. """Iterate through the args that are the given types (target) and
  160. return a list of the args that were traversed; arguments
  161. that are not of the specified types are not traversed.
  162. Examples
  163. ========
  164. >>> from sympy.core.traversal import walk
  165. >>> from sympy import Min, Max
  166. >>> from sympy.abc import x, y, z
  167. >>> list(walk(Min(x, Max(y, Min(1, z))), Min))
  168. [Min(x, Max(y, Min(1, z)))]
  169. >>> list(walk(Min(x, Max(y, Min(1, z))), Min, Max))
  170. [Min(x, Max(y, Min(1, z))), Max(y, Min(1, z)), Min(1, z)]
  171. See Also
  172. ========
  173. bottom_up
  174. """
  175. if isinstance(e, target):
  176. yield e
  177. for i in e.args:
  178. yield from walk(i, *target)
  179. def bottom_up(rv, F, atoms=False, nonbasic=False):
  180. """Apply ``F`` to all expressions in an expression tree from the
  181. bottom up. If ``atoms`` is True, apply ``F`` even if there are no args;
  182. if ``nonbasic`` is True, try to apply ``F`` to non-Basic objects.
  183. """
  184. args = getattr(rv, 'args', None)
  185. if args is not None:
  186. if args:
  187. args = tuple([bottom_up(a, F, atoms, nonbasic) for a in args])
  188. if args != rv.args:
  189. rv = rv.func(*args)
  190. rv = F(rv)
  191. elif atoms:
  192. rv = F(rv)
  193. else:
  194. if nonbasic:
  195. try:
  196. rv = F(rv)
  197. except TypeError:
  198. pass
  199. return rv
  200. def postorder_traversal(node, keys=None):
  201. """
  202. Do a postorder traversal of a tree.
  203. This generator recursively yields nodes that it has visited in a postorder
  204. fashion. That is, it descends through the tree depth-first to yield all of
  205. a node's children's postorder traversal before yielding the node itself.
  206. Parameters
  207. ==========
  208. node : SymPy expression
  209. The expression to traverse.
  210. keys : (default None) sort key(s)
  211. The key(s) used to sort args of Basic objects. When None, args of Basic
  212. objects are processed in arbitrary order. If key is defined, it will
  213. be passed along to ordered() as the only key(s) to use to sort the
  214. arguments; if ``key`` is simply True then the default keys of
  215. ``ordered`` will be used (node count and default_sort_key).
  216. Yields
  217. ======
  218. subtree : SymPy expression
  219. All of the subtrees in the tree.
  220. Examples
  221. ========
  222. >>> from sympy import postorder_traversal
  223. >>> from sympy.abc import w, x, y, z
  224. The nodes are returned in the order that they are encountered unless key
  225. is given; simply passing key=True will guarantee that the traversal is
  226. unique.
  227. >>> list(postorder_traversal(w + (x + y)*z)) # doctest: +SKIP
  228. [z, y, x, x + y, z*(x + y), w, w + z*(x + y)]
  229. >>> list(postorder_traversal(w + (x + y)*z, keys=True))
  230. [w, z, x, y, x + y, z*(x + y), w + z*(x + y)]
  231. """
  232. if isinstance(node, Basic):
  233. args = node.args
  234. if keys:
  235. if keys != True:
  236. args = ordered(args, keys, default=False)
  237. else:
  238. args = ordered(args)
  239. for arg in args:
  240. yield from postorder_traversal(arg, keys)
  241. elif iterable(node):
  242. for item in node:
  243. yield from postorder_traversal(item, keys)
  244. yield node