algorithms.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. from sympy.core.containers import Tuple
  2. from sympy.core.numbers import oo
  3. from sympy.core.relational import (Gt, Lt)
  4. from sympy.core.symbol import (Dummy, Symbol)
  5. from sympy.functions.elementary.complexes import Abs
  6. from sympy.functions.elementary.miscellaneous import Min, Max
  7. from sympy.logic.boolalg import And
  8. from sympy.codegen.ast import (
  9. Assignment, AddAugmentedAssignment, break_, CodeBlock, Declaration, FunctionDefinition,
  10. Print, Return, Scope, While, Variable, Pointer, real
  11. )
  12. from sympy.codegen.cfunctions import isnan
# NOTE: this string follows the import statements, so Python treats it as an
# ordinary (no-op) expression statement; it does not become the module's
# ``__doc__``. To act as the module docstring it would have to be the first
# statement of the file.
""" This module collects functions for constructing ASTs representing algorithms. """
  14. def newtons_method(expr, wrt, atol=1e-12, delta=None, *, rtol=4e-16, debug=False,
  15. itermax=None, counter=None, delta_fn=lambda e, x: -e/e.diff(x),
  16. cse=False, handle_nan=None,
  17. bounds=None):
  18. """ Generates an AST for Newton-Raphson method (a root-finding algorithm).
  19. Explanation
  20. ===========
  21. Returns an abstract syntax tree (AST) based on ``sympy.codegen.ast`` for Netwon's
  22. method of root-finding.
  23. Parameters
  24. ==========
  25. expr : expression
  26. wrt : Symbol
  27. With respect to, i.e. what is the variable.
  28. atol : number or expression
  29. Absolute tolerance (stopping criterion)
  30. rtol : number or expression
  31. Relative tolerance (stopping criterion)
  32. delta : Symbol
  33. Will be a ``Dummy`` if ``None``.
  34. debug : bool
  35. Whether to print convergence information during iterations
  36. itermax : number or expr
  37. Maximum number of iterations.
  38. counter : Symbol
  39. Will be a ``Dummy`` if ``None``.
  40. delta_fn: Callable[[Expr, Symbol], Expr]
  41. computes the step, default is newtons method. For e.g. Halley's method
  42. use delta_fn=lambda e, x: -2*e*e.diff(x)/(2*e.diff(x)**2 - e*e.diff(x, 2))
  43. cse: bool
  44. Perform common sub-expression elimination on delta expression
  45. handle_nan: Token
  46. How to handle occurrence of not-a-number (NaN).
  47. bounds: Optional[tuple[Expr, Expr]]
  48. Perform optimization within bounds
  49. Examples
  50. ========
  51. >>> from sympy import symbols, cos
  52. >>> from sympy.codegen.ast import Assignment
  53. >>> from sympy.codegen.algorithms import newtons_method
  54. >>> x, dx, atol = symbols('x dx atol')
  55. >>> expr = cos(x) - x**3
  56. >>> algo = newtons_method(expr, x, atol=atol, delta=dx)
  57. >>> algo.has(Assignment(dx, -expr/expr.diff(x)))
  58. True
  59. References
  60. ==========
  61. .. [1] https://en.wikipedia.org/wiki/Newton%27s_method
  62. """
  63. if delta is None:
  64. delta = Dummy()
  65. Wrapper = Scope
  66. name_d = 'delta'
  67. else:
  68. Wrapper = lambda x: x
  69. name_d = delta.name
  70. delta_expr = delta_fn(expr, wrt)
  71. if cse:
  72. from sympy.simplify.cse_main import cse
  73. cses, (red,) = cse([delta_expr.factor()])
  74. whl_bdy = [Assignment(dum, sub_e) for dum, sub_e in cses]
  75. whl_bdy += [Assignment(delta, red)]
  76. else:
  77. whl_bdy = [Assignment(delta, delta_expr)]
  78. if handle_nan is not None:
  79. whl_bdy += [While(isnan(delta), CodeBlock(handle_nan, break_))]
  80. whl_bdy += [AddAugmentedAssignment(wrt, delta)]
  81. if bounds is not None:
  82. whl_bdy += [Assignment(wrt, Min(Max(wrt, bounds[0]), bounds[1]))]
  83. if debug:
  84. prnt = Print([wrt, delta], r"{}=%12.5g {}=%12.5g\n".format(wrt.name, name_d))
  85. whl_bdy += [prnt]
  86. req = Gt(Abs(delta), atol + rtol*Abs(wrt))
  87. declars = [Declaration(Variable(delta, type=real, value=oo))]
  88. if itermax is not None:
  89. counter = counter or Dummy(integer=True)
  90. v_counter = Variable.deduced(counter, 0)
  91. declars.append(Declaration(v_counter))
  92. whl_bdy.append(AddAugmentedAssignment(counter, 1))
  93. req = And(req, Lt(counter, itermax))
  94. whl = While(req, CodeBlock(*whl_bdy))
  95. blck = declars
  96. if debug:
  97. blck.append(Print([wrt], r"{}=%12.5g\n".format(wrt.name)))
  98. blck += [whl]
  99. return Wrapper(CodeBlock(*blck))
  100. def _symbol_of(arg):
  101. if isinstance(arg, Declaration):
  102. arg = arg.variable.symbol
  103. elif isinstance(arg, Variable):
  104. arg = arg.symbol
  105. return arg
  106. def newtons_method_function(expr, wrt, params=None, func_name="newton", attrs=Tuple(), *, delta=None, **kwargs):
  107. """ Generates an AST for a function implementing the Newton-Raphson method.
  108. Parameters
  109. ==========
  110. expr : expression
  111. wrt : Symbol
  112. With respect to, i.e. what is the variable
  113. params : iterable of symbols
  114. Symbols appearing in expr that are taken as constants during the iterations
  115. (these will be accepted as parameters to the generated function).
  116. func_name : str
  117. Name of the generated function.
  118. attrs : Tuple
  119. Attribute instances passed as ``attrs`` to ``FunctionDefinition``.
  120. \\*\\*kwargs :
  121. Keyword arguments passed to :func:`sympy.codegen.algorithms.newtons_method`.
  122. Examples
  123. ========
  124. >>> from sympy import symbols, cos
  125. >>> from sympy.codegen.algorithms import newtons_method_function
  126. >>> from sympy.codegen.pyutils import render_as_module
  127. >>> x = symbols('x')
  128. >>> expr = cos(x) - x**3
  129. >>> func = newtons_method_function(expr, x)
  130. >>> py_mod = render_as_module(func) # source code as string
  131. >>> namespace = {}
  132. >>> exec(py_mod, namespace, namespace)
  133. >>> res = eval('newton(0.5)', namespace)
  134. >>> abs(res - 0.865474033102) < 1e-12
  135. True
  136. See Also
  137. ========
  138. sympy.codegen.algorithms.newtons_method
  139. """
  140. if params is None:
  141. params = (wrt,)
  142. pointer_subs = {p.symbol: Symbol('(*%s)' % p.symbol.name)
  143. for p in params if isinstance(p, Pointer)}
  144. if delta is None:
  145. delta = Symbol('d_' + wrt.name)
  146. if expr.has(delta):
  147. delta = None # will use Dummy
  148. algo = newtons_method(expr, wrt, delta=delta, **kwargs).xreplace(pointer_subs)
  149. if isinstance(algo, Scope):
  150. algo = algo.body
  151. not_in_params = expr.free_symbols.difference({_symbol_of(p) for p in params})
  152. if not_in_params:
  153. raise ValueError("Missing symbols in params: %s" % ', '.join(map(str, not_in_params)))
  154. declars = tuple(Variable(p, real) for p in params)
  155. body = CodeBlock(algo, Return(wrt))
  156. return FunctionDefinition(real, func_name, declars, body, attrs=attrs)