test_algorithms.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. import tempfile
  2. from sympy import log, Min, Max, sqrt
  3. from sympy.core.numbers import Float
  4. from sympy.core.symbol import Symbol, symbols
  5. from sympy.functions.elementary.trigonometric import cos
  6. from sympy.codegen.ast import Assignment, Raise, RuntimeError_, QuotedString
  7. from sympy.codegen.algorithms import newtons_method, newtons_method_function
  8. from sympy.codegen.cfunctions import expm1
  9. from sympy.codegen.fnodes import bind_C
  10. from sympy.codegen.futils import render_as_module as f_module
  11. from sympy.codegen.pyutils import render_as_module as py_module
  12. from sympy.external import import_module
  13. from sympy.printing.codeprinter import ccode
  14. from sympy.utilities._compilation import compile_link_import_strings, has_c, has_fortran
  15. from sympy.utilities._compilation.util import may_xfail
  16. from sympy.testing.pytest import skip, raises, skip_under_pyodide
  17. cython = import_module('cython')
  18. wurlitzer = import_module('wurlitzer')
  19. def test_newtons_method():
  20. x, dx, atol = symbols('x dx atol')
  21. expr = cos(x) - x**3
  22. algo = newtons_method(expr, x, atol, dx)
  23. assert algo.has(Assignment(dx, -expr/expr.diff(x)))
  24. @may_xfail
  25. def test_newtons_method_function__ccode():
  26. x = Symbol('x', real=True)
  27. expr = cos(x) - x**3
  28. func = newtons_method_function(expr, x)
  29. if not cython:
  30. skip("cython not installed.")
  31. if not has_c():
  32. skip("No C compiler found.")
  33. compile_kw = {"std": 'c99'}
  34. with tempfile.TemporaryDirectory() as folder:
  35. mod, info = compile_link_import_strings([
  36. ('newton.c', ('#include <math.h>\n'
  37. '#include <stdio.h>\n') + ccode(func)),
  38. ('_newton.pyx', ("#cython: language_level={}\n".format("3") +
  39. "cdef extern double newton(double)\n"
  40. "def py_newton(x):\n"
  41. " return newton(x)\n"))
  42. ], build_dir=folder, compile_kwargs=compile_kw)
  43. assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12
  44. @may_xfail
  45. def test_newtons_method_function__fcode():
  46. x = Symbol('x', real=True)
  47. expr = cos(x) - x**3
  48. func = newtons_method_function(expr, x, attrs=[bind_C(name='newton')])
  49. if not cython:
  50. skip("cython not installed.")
  51. if not has_fortran():
  52. skip("No Fortran compiler found.")
  53. f_mod = f_module([func], 'mod_newton')
  54. with tempfile.TemporaryDirectory() as folder:
  55. mod, info = compile_link_import_strings([
  56. ('newton.f90', f_mod),
  57. ('_newton.pyx', ("#cython: language_level={}\n".format("3") +
  58. "cdef extern double newton(double*)\n"
  59. "def py_newton(double x):\n"
  60. " return newton(&x)\n"))
  61. ], build_dir=folder)
  62. assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12
  63. def test_newtons_method_function__pycode():
  64. x = Symbol('x', real=True)
  65. expr = cos(x) - x**3
  66. func = newtons_method_function(expr, x)
  67. py_mod = py_module(func)
  68. namespace = {}
  69. exec(py_mod, namespace, namespace)
  70. res = eval('newton(0.5)', namespace)
  71. assert abs(res - 0.865474033102) < 1e-12
  72. @may_xfail
  73. @skip_under_pyodide("Emscripten does not support process spawning")
  74. def test_newtons_method_function__ccode_parameters():
  75. args = x, A, k, p = symbols('x A k p')
  76. expr = A*cos(k*x) - p*x**3
  77. raises(ValueError, lambda: newtons_method_function(expr, x))
  78. use_wurlitzer = wurlitzer
  79. func = newtons_method_function(expr, x, args, debug=use_wurlitzer)
  80. if not has_c():
  81. skip("No C compiler found.")
  82. if not cython:
  83. skip("cython not installed.")
  84. compile_kw = {"std": 'c99'}
  85. with tempfile.TemporaryDirectory() as folder:
  86. mod, info = compile_link_import_strings([
  87. ('newton_par.c', ('#include <math.h>\n'
  88. '#include <stdio.h>\n') + ccode(func)),
  89. ('_newton_par.pyx', ("#cython: language_level={}\n".format("3") +
  90. "cdef extern double newton(double, double, double, double)\n"
  91. "def py_newton(x, A=1, k=1, p=1):\n"
  92. " return newton(x, A, k, p)\n"))
  93. ], compile_kwargs=compile_kw, build_dir=folder)
  94. if use_wurlitzer:
  95. with wurlitzer.pipes() as (out, err):
  96. result = mod.py_newton(0.5)
  97. else:
  98. result = mod.py_newton(0.5)
  99. assert abs(result - 0.865474033102) < 1e-12
  100. if not use_wurlitzer:
  101. skip("C-level output only tested when package 'wurlitzer' is available.")
  102. out, err = out.read(), err.read()
  103. assert err == ''
  104. assert out == """\
  105. x= 0.5
  106. x= 1.1121 d_x= 0.61214
  107. x= 0.90967 d_x= -0.20247
  108. x= 0.86726 d_x= -0.042409
  109. x= 0.86548 d_x= -0.0017867
  110. x= 0.86547 d_x= -3.1022e-06
  111. x= 0.86547 d_x= -9.3421e-12
  112. x= 0.86547 d_x= 3.6902e-17
  113. """ # try to run tests with LC_ALL=C if this assertion fails
  114. def test_newtons_method_function__rtol_cse_nan():
  115. a, b, c, N_geo, N_tot = symbols('a b c N_geo N_tot', real=True, nonnegative=True)
  116. i = Symbol('i', integer=True, nonnegative=True)
  117. N_ari = N_tot - N_geo - 1
  118. delta_ari = (c-b)/N_ari
  119. ln_delta_geo = log(b) + log(-expm1((log(a)-log(b))/N_geo))
  120. eqb_log = ln_delta_geo - log(delta_ari)
  121. def _clamp(low, expr, high):
  122. return Min(Max(low, expr), high)
  123. meth_kw = {
  124. 'clamped_newton': {'delta_fn': lambda e, x: _clamp(
  125. (sqrt(a*x)-x)*0.99,
  126. -e/e.diff(x),
  127. (sqrt(c*x)-x)*0.99
  128. )},
  129. 'halley': {'delta_fn': lambda e, x: (-2*(e*e.diff(x))/(2*e.diff(x)**2 - e*e.diff(x, 2)))},
  130. 'halley_alt': {'delta_fn': lambda e, x: (-e/e.diff(x)/(1-e/e.diff(x)*e.diff(x,2)/2/e.diff(x)))},
  131. }
  132. args = eqb_log, b
  133. for use_cse in [False, True]:
  134. kwargs = {
  135. 'params': (b, a, c, N_geo, N_tot), 'itermax': 60, 'debug': True, 'cse': use_cse,
  136. 'counter': i, 'atol': 1e-100, 'rtol': 2e-16, 'bounds': (a,c),
  137. 'handle_nan': Raise(RuntimeError_(QuotedString("encountered NaN.")))
  138. }
  139. func = {k: newtons_method_function(*args, func_name=f"{k}_b", **dict(kwargs, **kw)) for k, kw in meth_kw.items()}
  140. py_mod = {k: py_module(v) for k, v in func.items()}
  141. namespace = {}
  142. root_find_b = {}
  143. for k, v in py_mod.items():
  144. ns = namespace[k] = {}
  145. exec(v, ns, ns)
  146. root_find_b[k] = ns[f'{k}_b']
  147. ref = Float('13.2261515064168768938151923226496')
  148. reftol = {'clamped_newton': 2e-16, 'halley': 2e-16, 'halley_alt': 3e-16}
  149. guess = 4.0
  150. for meth, func in root_find_b.items():
  151. result = func(guess, 1e-2, 1e2, 50, 100)
  152. req = ref*reftol[meth]
  153. if use_cse:
  154. req *= 2
  155. assert abs(result - ref) < req