test_zeros.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995
  1. import warnings
  2. from functools import lru_cache
  3. import pytest
  4. from numpy.testing import (assert_,
  5. assert_allclose,
  6. assert_equal,
  7. assert_array_equal)
  8. import numpy as np
  9. from numpy import finfo, power, nan, isclose, sqrt, exp, sin, cos
  10. from scipy import optimize
  11. from scipy.optimize import (_zeros_py as zeros, newton, root_scalar,
  12. OptimizeResult)
  13. from scipy._lib._util import getfullargspec_no_self as _getfullargspec
  14. # Import testing parameters
  15. from scipy.optimize._tstutils import get_tests, functions as tstutils_functions
  16. TOL = 4*np.finfo(float).eps # tolerance
  17. _FLOAT_EPS = finfo(float).eps
  18. bracket_methods = [zeros.bisect, zeros.ridder, zeros.brentq, zeros.brenth,
  19. zeros.toms748]
  20. gradient_methods = [zeros.newton]
  21. all_methods = bracket_methods + gradient_methods
  22. # A few test functions used frequently:
  23. # # A simple quadratic, (x-1)^2 - 1
  24. def f1(x):
  25. return x ** 2 - 2 * x - 1
  26. def f1_1(x):
  27. return 2 * x - 2
  28. def f1_2(x):
  29. return 2.0 + 0 * x
  30. def f1_and_p_and_pp(x):
  31. return f1(x), f1_1(x), f1_2(x)
  32. # Simple transcendental function
  33. def f2(x):
  34. return exp(x) - cos(x)
  35. def f2_1(x):
  36. return exp(x) + sin(x)
  37. def f2_2(x):
  38. return exp(x) + cos(x)
  39. # lru cached function
  40. @lru_cache
  41. def f_lrucached(x):
  42. return x
  43. class TestScalarRootFinders:
  44. # Basic tests for all scalar root finders
  45. xtol = 4 * np.finfo(float).eps
  46. rtol = 4 * np.finfo(float).eps
  47. def _run_one_test(self, tc, method, sig_args_keys=None,
  48. sig_kwargs_keys=None, **kwargs):
  49. method_args = []
  50. for k in sig_args_keys or []:
  51. if k not in tc:
  52. # If a,b not present use x0, x1. Similarly for f and func
  53. k = {'a': 'x0', 'b': 'x1', 'func': 'f'}.get(k, k)
  54. method_args.append(tc[k])
  55. method_kwargs = dict(**kwargs)
  56. method_kwargs.update({'full_output': True, 'disp': False})
  57. for k in sig_kwargs_keys or []:
  58. method_kwargs[k] = tc[k]
  59. root = tc.get('root')
  60. func_args = tc.get('args', ())
  61. try:
  62. r, rr = method(*method_args, args=func_args, **method_kwargs)
  63. return root, rr, tc
  64. except Exception:
  65. return root, zeros.RootResults(nan, -1, -1, zeros._EVALUEERR, method), tc
  66. def run_tests(self, tests, method, name, known_fail=None, **kwargs):
  67. r"""Run test-cases using the specified method and the supplied signature.
  68. Extract the arguments for the method call from the test case
  69. dictionary using the supplied keys for the method's signature."""
  70. # The methods have one of two base signatures:
  71. # (f, a, b, **kwargs) # newton
  72. # (func, x0, **kwargs) # bisect/brentq/...
  73. # FullArgSpec with args, varargs, varkw, defaults, ...
  74. sig = _getfullargspec(method)
  75. assert_(not sig.kwonlyargs)
  76. nDefaults = len(sig.defaults)
  77. nRequired = len(sig.args) - nDefaults
  78. sig_args_keys = sig.args[:nRequired]
  79. sig_kwargs_keys = []
  80. if name in ['secant', 'newton', 'halley']:
  81. if name in ['newton', 'halley']:
  82. sig_kwargs_keys.append('fprime')
  83. if name in ['halley']:
  84. sig_kwargs_keys.append('fprime2')
  85. kwargs['tol'] = self.xtol
  86. else:
  87. kwargs['xtol'] = self.xtol
  88. kwargs['rtol'] = self.rtol
  89. results = [list(self._run_one_test(
  90. tc, method, sig_args_keys=sig_args_keys,
  91. sig_kwargs_keys=sig_kwargs_keys, **kwargs)) for tc in tests]
  92. # results= [[true root, full output, tc], ...]
  93. known_fail = known_fail or []
  94. notcvgd = [elt for elt in results if not elt[1].converged]
  95. notcvgd = [elt for elt in notcvgd if elt[-1]['ID'] not in known_fail]
  96. notcvged_IDS = [elt[-1]['ID'] for elt in notcvgd]
  97. assert_equal([len(notcvged_IDS), notcvged_IDS], [0, []])
  98. # The usable xtol and rtol depend on the test
  99. tols = {'xtol': self.xtol, 'rtol': self.rtol}
  100. tols.update(**kwargs)
  101. rtol = tols['rtol']
  102. atol = tols.get('tol', tols['xtol'])
  103. cvgd = [elt for elt in results if elt[1].converged]
  104. approx = [elt[1].root for elt in cvgd]
  105. correct = [elt[0] for elt in cvgd]
  106. # See if the root matches the reference value
  107. notclose = [[a] + elt for a, c, elt in zip(approx, correct, cvgd) if
  108. not isclose(a, c, rtol=rtol, atol=atol)
  109. and elt[-1]['ID'] not in known_fail]
  110. # If not, evaluate the function and see if is 0 at the purported root
  111. fvs = [tc['f'](aroot, *tc.get('args', tuple()))
  112. for aroot, c, fullout, tc in notclose]
  113. notclose = [[fv] + elt for fv, elt in zip(fvs, notclose) if fv != 0]
  114. assert_equal([notclose, len(notclose)], [[], 0])
  115. method_from_result = [result[1].method for result in results]
  116. expected_method = [name for _ in results]
  117. assert_equal(method_from_result, expected_method)
  118. def run_collection(self, collection, method, name, smoothness=None,
  119. known_fail=None, **kwargs):
  120. r"""Run a collection of tests using the specified method.
  121. The name is used to determine some optional arguments."""
  122. tests = get_tests(collection, smoothness=smoothness)
  123. self.run_tests(tests, method, name, known_fail=known_fail, **kwargs)
  124. class TestBracketMethods(TestScalarRootFinders):
  125. @pytest.mark.parametrize('method', bracket_methods)
  126. @pytest.mark.parametrize('function', tstutils_functions)
  127. def test_basic_root_scalar(self, method, function):
  128. # Tests bracketing root finders called via `root_scalar` on a small
  129. # set of simple problems, each of which has a root at `x=1`. Checks for
  130. # converged status and that the root was found.
  131. a, b = .5, sqrt(3)
  132. r = root_scalar(function, method=method.__name__, bracket=[a, b], x0=a,
  133. xtol=self.xtol, rtol=self.rtol)
  134. assert r.converged
  135. assert_allclose(r.root, 1.0, atol=self.xtol, rtol=self.rtol)
  136. assert r.method == method.__name__
  137. @pytest.mark.parametrize('method', bracket_methods)
  138. @pytest.mark.parametrize('function', tstutils_functions)
  139. def test_basic_individual(self, method, function):
  140. # Tests individual bracketing root finders on a small set of simple
  141. # problems, each of which has a root at `x=1`. Checks for converged
  142. # status and that the root was found.
  143. a, b = .5, sqrt(3)
  144. root, r = method(function, a, b, xtol=self.xtol, rtol=self.rtol,
  145. full_output=True)
  146. assert r.converged
  147. assert_allclose(root, 1.0, atol=self.xtol, rtol=self.rtol)
  148. @pytest.mark.parametrize('method', bracket_methods)
  149. @pytest.mark.parametrize('function', tstutils_functions)
  150. def test_bracket_is_array(self, method, function):
  151. # Test bracketing root finders called via `root_scalar` on a small set
  152. # of simple problems, each of which has a root at `x=1`. Check that
  153. # passing `bracket` as a `ndarray` is accepted and leads to finding the
  154. # correct root.
  155. a, b = .5, sqrt(3)
  156. r = root_scalar(function, method=method.__name__,
  157. bracket=np.array([a, b]), x0=a, xtol=self.xtol,
  158. rtol=self.rtol)
  159. assert r.converged
  160. assert_allclose(r.root, 1.0, atol=self.xtol, rtol=self.rtol)
  161. assert r.method == method.__name__
  162. @pytest.mark.parametrize('method', bracket_methods)
  163. def test_aps_collection(self, method):
  164. self.run_collection('aps', method, method.__name__, smoothness=1)
  165. @pytest.mark.parametrize('method', [zeros.bisect, zeros.ridder,
  166. zeros.toms748])
  167. def test_chandrupatla_collection(self, method):
  168. known_fail = {'fun7.4'} if method == zeros.ridder else {}
  169. self.run_collection('chandrupatla', method, method.__name__,
  170. known_fail=known_fail)
  171. @pytest.mark.parametrize('method', bracket_methods)
  172. def test_lru_cached_individual(self, method):
  173. # check that https://github.com/scipy/scipy/issues/10846 is fixed
  174. # (`root_scalar` failed when passed a function that was `@lru_cache`d)
  175. a, b = -1, 1
  176. root, r = method(f_lrucached, a, b, full_output=True)
  177. assert r.converged
  178. assert_allclose(root, 0)
  179. def test_gh_22934(self):
  180. with pytest.raises(ValueError, match="maxiter must be >= 0"):
  181. zeros.brentq(lambda x: x**2 - 1, -2, 0, maxiter=-1)
  182. class TestNewton(TestScalarRootFinders):
  183. def test_newton_collections(self):
  184. known_fail = ['aps.13.00']
  185. known_fail += ['aps.12.05', 'aps.12.17'] # fails under Windows Py27
  186. for collection in ['aps', 'complex']:
  187. self.run_collection(collection, zeros.newton, 'newton',
  188. smoothness=2, known_fail=known_fail)
  189. def test_halley_collections(self):
  190. known_fail = ['aps.12.06', 'aps.12.07', 'aps.12.08', 'aps.12.09',
  191. 'aps.12.10', 'aps.12.11', 'aps.12.12', 'aps.12.13',
  192. 'aps.12.14', 'aps.12.15', 'aps.12.16', 'aps.12.17',
  193. 'aps.12.18', 'aps.13.00']
  194. for collection in ['aps', 'complex']:
  195. self.run_collection(collection, zeros.newton, 'halley',
  196. smoothness=2, known_fail=known_fail)
  197. def test_newton(self):
  198. for f, f_1, f_2 in [(f1, f1_1, f1_2), (f2, f2_1, f2_2)]:
  199. x = zeros.newton(f, 3, tol=1e-6)
  200. assert_allclose(f(x), 0, atol=1e-6)
  201. x = zeros.newton(f, 3, x1=5, tol=1e-6) # secant, x0 and x1
  202. assert_allclose(f(x), 0, atol=1e-6)
  203. x = zeros.newton(f, 3, fprime=f_1, tol=1e-6) # newton
  204. assert_allclose(f(x), 0, atol=1e-6)
  205. x = zeros.newton(f, 3, fprime=f_1, fprime2=f_2, tol=1e-6) # halley
  206. assert_allclose(f(x), 0, atol=1e-6)
  207. def test_newton_by_name(self):
  208. r"""Invoke newton through root_scalar()"""
  209. for f, f_1, f_2 in [(f1, f1_1, f1_2), (f2, f2_1, f2_2)]:
  210. r = root_scalar(f, method='newton', x0=3, fprime=f_1, xtol=1e-6)
  211. assert_allclose(f(r.root), 0, atol=1e-6)
  212. for f, f_1, f_2 in [(f1, f1_1, f1_2), (f2, f2_1, f2_2)]:
  213. r = root_scalar(f, method='newton', x0=3, xtol=1e-6) # without f'
  214. assert_allclose(f(r.root), 0, atol=1e-6)
  215. def test_secant_by_name(self):
  216. r"""Invoke secant through root_scalar()"""
  217. for f, f_1, f_2 in [(f1, f1_1, f1_2), (f2, f2_1, f2_2)]:
  218. r = root_scalar(f, method='secant', x0=3, x1=2, xtol=1e-6)
  219. assert_allclose(f(r.root), 0, atol=1e-6)
  220. r = root_scalar(f, method='secant', x0=3, x1=5, xtol=1e-6)
  221. assert_allclose(f(r.root), 0, atol=1e-6)
  222. for f, f_1, f_2 in [(f1, f1_1, f1_2), (f2, f2_1, f2_2)]:
  223. r = root_scalar(f, method='secant', x0=3, xtol=1e-6) # without x1
  224. assert_allclose(f(r.root), 0, atol=1e-6)
  225. def test_halley_by_name(self):
  226. r"""Invoke halley through root_scalar()"""
  227. for f, f_1, f_2 in [(f1, f1_1, f1_2), (f2, f2_1, f2_2)]:
  228. r = root_scalar(f, method='halley', x0=3,
  229. fprime=f_1, fprime2=f_2, xtol=1e-6)
  230. assert_allclose(f(r.root), 0, atol=1e-6)
  231. def test_root_scalar_fail(self):
  232. message = 'fprime2 must be specified for halley'
  233. with pytest.raises(ValueError, match=message):
  234. root_scalar(f1, method='halley', fprime=f1_1, x0=3, xtol=1e-6) # no fprime2
  235. message = 'fprime must be specified for halley'
  236. with pytest.raises(ValueError, match=message):
  237. root_scalar(f1, method='halley', fprime2=f1_2, x0=3, xtol=1e-6) # no fprime
  238. def test_array_newton(self):
  239. """test newton with array"""
  240. def f1(x, *a):
  241. b = a[0] + x * a[3]
  242. return a[1] - a[2] * (np.exp(b / a[5]) - 1.0) - b / a[4] - x
  243. def f1_1(x, *a):
  244. b = a[3] / a[5]
  245. return -a[2] * np.exp(a[0] / a[5] + x * b) * b - a[3] / a[4] - 1
  246. def f1_2(x, *a):
  247. b = a[3] / a[5]
  248. return -a[2] * np.exp(a[0] / a[5] + x * b) * b**2
  249. a0 = np.array([
  250. 5.32725221, 5.48673747, 5.49539973,
  251. 5.36387202, 4.80237316, 1.43764452,
  252. 5.23063958, 5.46094772, 5.50512718,
  253. 5.42046290
  254. ])
  255. a1 = (np.sin(range(10)) + 1.0) * 7.0
  256. args = (a0, a1, 1e-09, 0.004, 10, 0.27456)
  257. x0 = [7.0] * 10
  258. x = zeros.newton(f1, x0, f1_1, args)
  259. x_expected = (
  260. 6.17264965, 11.7702805, 12.2219954,
  261. 7.11017681, 1.18151293, 0.143707955,
  262. 4.31928228, 10.5419107, 12.7552490,
  263. 8.91225749
  264. )
  265. assert_allclose(x, x_expected)
  266. # test halley's
  267. x = zeros.newton(f1, x0, f1_1, args, fprime2=f1_2)
  268. assert_allclose(x, x_expected)
  269. # test secant
  270. x = zeros.newton(f1, x0, args=args)
  271. assert_allclose(x, x_expected)
  272. def test_array_newton_complex(self):
  273. def f(x):
  274. return x + 1+1j
  275. def fprime(x):
  276. return 1.0
  277. t = np.full(4, 1j)
  278. x = zeros.newton(f, t, fprime=fprime)
  279. assert_allclose(f(x), 0.)
  280. # should work even if x0 is not complex
  281. t = np.ones(4)
  282. x = zeros.newton(f, t, fprime=fprime)
  283. assert_allclose(f(x), 0.)
  284. x = zeros.newton(f, t)
  285. assert_allclose(f(x), 0.)
  286. def test_array_secant_active_zero_der(self):
  287. """test secant doesn't continue to iterate zero derivatives"""
  288. x = zeros.newton(lambda x, *a: x*x - a[0], x0=[4.123, 5],
  289. args=[np.array([17, 25])])
  290. assert_allclose(x, (4.123105625617661, 5.0))
  291. def test_array_newton_integers(self):
  292. # test secant with float
  293. x = zeros.newton(lambda y, z: z - y ** 2, [4.0] * 2,
  294. args=([15.0, 17.0],))
  295. assert_allclose(x, (3.872983346207417, 4.123105625617661))
  296. # test integer becomes float
  297. x = zeros.newton(lambda y, z: z - y ** 2, [4] * 2, args=([15, 17],))
  298. assert_allclose(x, (3.872983346207417, 4.123105625617661))
  299. def test_array_newton_zero_der_failures(self):
  300. # test derivative zero warning
  301. with pytest.warns(RuntimeWarning):
  302. zeros.newton(lambda y: y**2 - 2, [0., 0.], lambda y: 2 * y)
  303. # test failures and zero_der
  304. with pytest.warns(RuntimeWarning):
  305. results = zeros.newton(lambda y: y**2 - 2, [0., 0.],
  306. lambda y: 2*y, full_output=True)
  307. assert_allclose(results.root, 0)
  308. assert results.zero_der.all()
  309. assert not results.converged.any()
  310. def test_newton_combined(self):
  311. def f1(x):
  312. return x ** 2 - 2 * x - 1
  313. def f1_1(x):
  314. return 2 * x - 2
  315. def f1_2(x):
  316. return 2.0 + 0 * x
  317. def f1_and_p_and_pp(x):
  318. return x**2 - 2*x-1, 2*x-2, 2.0
  319. sol0 = root_scalar(f1, method='newton', x0=3, fprime=f1_1)
  320. sol = root_scalar(f1_and_p_and_pp, method='newton', x0=3, fprime=True)
  321. assert_allclose(sol0.root, sol.root, atol=1e-8)
  322. assert_equal(2*sol.function_calls, sol0.function_calls)
  323. sol0 = root_scalar(f1, method='halley', x0=3, fprime=f1_1, fprime2=f1_2)
  324. sol = root_scalar(f1_and_p_and_pp, method='halley', x0=3, fprime2=True)
  325. assert_allclose(sol0.root, sol.root, atol=1e-8)
  326. assert_equal(3*sol.function_calls, sol0.function_calls)
  327. def test_newton_full_output(self, capsys):
  328. # Test the full_output capability, both when converging and not.
  329. # Use simple polynomials, to avoid hitting platform dependencies
  330. # (e.g., exp & trig) in number of iterations
  331. x0 = 3
  332. expected_counts = [(6, 7), (5, 10), (3, 9)]
  333. for derivs in range(3):
  334. kwargs = {'tol': 1e-6, 'full_output': True, }
  335. for k, v in [['fprime', f1_1], ['fprime2', f1_2]][:derivs]:
  336. kwargs[k] = v
  337. x, r = zeros.newton(f1, x0, disp=False, **kwargs)
  338. assert_(r.converged)
  339. assert_equal(x, r.root)
  340. assert_equal((r.iterations, r.function_calls), expected_counts[derivs])
  341. if derivs == 0:
  342. assert r.function_calls <= r.iterations + 1
  343. else:
  344. assert_equal(r.function_calls, (derivs + 1) * r.iterations)
  345. # Now repeat, allowing one fewer iteration to force convergence failure
  346. iters = r.iterations - 1
  347. x, r = zeros.newton(f1, x0, maxiter=iters, disp=False, **kwargs)
  348. assert_(not r.converged)
  349. assert_equal(x, r.root)
  350. assert_equal(r.iterations, iters)
  351. if derivs == 1:
  352. # Check that the correct Exception is raised and
  353. # validate the start of the message.
  354. msg = f'Failed to converge after {iters} iterations, value is .*'
  355. with pytest.raises(RuntimeError, match=msg):
  356. x, r = zeros.newton(f1, x0, maxiter=iters, disp=True, **kwargs)
  357. def test_deriv_zero_warning(self):
  358. def func(x):
  359. return x ** 2 - 2.0
  360. def dfunc(x):
  361. return 2 * x
  362. with pytest.warns(RuntimeWarning):
  363. zeros.newton(func, 0.0, dfunc, disp=False)
  364. with pytest.raises(RuntimeError, match='Derivative was zero'):
  365. zeros.newton(func, 0.0, dfunc)
  366. def test_newton_does_not_modify_x0(self):
  367. # https://github.com/scipy/scipy/issues/9964
  368. x0 = np.array([0.1, 3])
  369. x0_copy = x0.copy() # Copy to test for equality.
  370. newton(np.sin, x0, np.cos)
  371. assert_array_equal(x0, x0_copy)
  372. def test_gh17570_defaults(self):
  373. # Previously, when fprime was not specified, root_scalar would default
  374. # to secant. When x1 was not specified, secant failed.
  375. # Check that without fprime, the default is secant if x1 is specified
  376. # and newton otherwise.
  377. # Also confirm that `x` is always a scalar (gh-21148)
  378. def f(x):
  379. assert np.isscalar(x)
  380. return f1(x)
  381. res_newton_default = root_scalar(f, method='newton', x0=3, xtol=1e-6)
  382. res_secant_default = root_scalar(f, method='secant', x0=3, x1=2,
  383. xtol=1e-6)
  384. # `newton` uses the secant method when `x1` and `x2` are specified
  385. res_secant = newton(f, x0=3, x1=2, tol=1e-6, full_output=True)[1]
  386. # all three found a root
  387. assert_allclose(f(res_newton_default.root), 0, atol=1e-6)
  388. assert res_newton_default.root.shape == tuple()
  389. assert_allclose(f(res_secant_default.root), 0, atol=1e-6)
  390. assert res_secant_default.root.shape == tuple()
  391. assert_allclose(f(res_secant.root), 0, atol=1e-6)
  392. assert res_secant.root.shape == tuple()
  393. # Defaults are correct
  394. assert (res_secant_default.root
  395. == res_secant.root
  396. != res_newton_default.iterations)
  397. assert (res_secant_default.iterations
  398. == res_secant_default.function_calls - 1 # true for secant
  399. == res_secant.iterations
  400. != res_newton_default.iterations
  401. == res_newton_default.function_calls/2) # newton 2-point diff
  402. @pytest.mark.parametrize('kwargs', [dict(), {'method': 'newton'}])
  403. def test_args_gh19090(self, kwargs):
  404. def f(x, a, b):
  405. assert a == 3
  406. assert b == 1
  407. return (x ** a - b)
  408. res = optimize.root_scalar(f, x0=3, args=(3, 1), **kwargs)
  409. assert res.converged
  410. assert_allclose(res.root, 1)
  411. @pytest.mark.parametrize('method', ['secant', 'newton'])
  412. def test_int_x0_gh19280(self, method):
  413. # Originally, `newton` ensured that only floats were passed to the
  414. # callable. This was inadvertently changed by gh-17669. Check that
  415. # it has been changed back.
  416. def f(x):
  417. # an integer raised to a negative integer power would fail
  418. return x**-2 - 2
  419. res = optimize.root_scalar(f, x0=1, method=method)
  420. assert res.converged
  421. assert_allclose(abs(res.root), 2**-0.5)
  422. assert res.root.dtype == np.dtype(np.float64)
  423. def test_newton_special_parameters(self):
  424. # give zeros.newton() some strange parameters
  425. # and check whether an exception appears
  426. with pytest.raises(ValueError, match="tol too small"):
  427. zeros.newton(f1, 3, tol=-1e-6)
  428. with pytest.raises(ValueError, match="maxiter must be greater than 0"):
  429. zeros.newton(f1, 3, tol=1e-6, maxiter=-50)
  430. with pytest.raises(ValueError, match="x1 and x0 must be different" ):
  431. zeros.newton(f1, 3, x1=3)
  432. def test_gh_5555():
  433. root = 0.1
  434. def f(x):
  435. return x - root
  436. methods = [zeros.bisect, zeros.ridder]
  437. xtol = rtol = TOL
  438. for method in methods:
  439. res = method(f, -1e8, 1e7, xtol=xtol, rtol=rtol)
  440. assert_allclose(root, res, atol=xtol, rtol=rtol,
  441. err_msg=f'method {method.__name__}')
  442. def test_gh_5557():
  443. # Show that without the changes in 5557 brentq and brenth might
  444. # only achieve a tolerance of 2*(xtol + rtol*|res|).
  445. # f linearly interpolates (0, -0.1), (0.5, -0.1), and (1,
  446. # 0.4). The important parts are that |f(0)| < |f(1)| (so that
  447. # brent takes 0 as the initial guess), |f(0)| < atol (so that
  448. # brent accepts 0 as the root), and that the exact root of f lies
  449. # more than atol away from 0 (so that brent doesn't achieve the
  450. # desired tolerance).
  451. def f(x):
  452. if x < 0.5:
  453. return -0.1
  454. else:
  455. return x - 0.6
  456. atol = 0.51
  457. rtol = 4 * _FLOAT_EPS
  458. methods = [zeros.brentq, zeros.brenth]
  459. for method in methods:
  460. res = method(f, 0, 1, xtol=atol, rtol=rtol)
  461. assert_allclose(0.6, res, atol=atol, rtol=rtol)
  462. def test_brent_underflow_in_root_bracketing():
  463. # Testing if an interval [a,b] brackets a zero of a function
  464. # by checking f(a)*f(b) < 0 is not reliable when the product
  465. # underflows/overflows. (reported in issue# 13737)
  466. underflow_scenario = (-450.0, -350.0, -400.0)
  467. overflow_scenario = (350.0, 450.0, 400.0)
  468. for a, b, root in [underflow_scenario, overflow_scenario]:
  469. c = np.exp(root)
  470. for method in [zeros.brenth, zeros.brentq]:
  471. res = method(lambda x: np.exp(x)-c, a, b)
  472. assert_allclose(root, res)
  473. class TestRootResults:
  474. r = zeros.RootResults(root=1.0, iterations=44, function_calls=46, flag=0,
  475. method="newton")
  476. def test_repr(self):
  477. expected_repr = (" converged: True\n flag: converged"
  478. "\n function_calls: 46\n iterations: 44\n"
  479. " root: 1.0\n method: newton")
  480. assert_equal(repr(self.r), expected_repr)
  481. def test_type(self):
  482. assert isinstance(self.r, OptimizeResult)
  483. def test_complex_halley():
  484. """Test Halley's works with complex roots"""
  485. def f(x, *a):
  486. return a[0] * x**2 + a[1] * x + a[2]
  487. def f_1(x, *a):
  488. return 2 * a[0] * x + a[1]
  489. def f_2(x, *a):
  490. retval = 2 * a[0]
  491. try:
  492. size = len(x)
  493. except TypeError:
  494. return retval
  495. else:
  496. return [retval] * size
  497. z = complex(1.0, 2.0)
  498. coeffs = (2.0, 3.0, 4.0)
  499. y = zeros.newton(f, z, args=coeffs, fprime=f_1, fprime2=f_2, tol=1e-6)
  500. # (-0.75000000000000078+1.1989578808281789j)
  501. assert_allclose(f(y, *coeffs), 0, atol=1e-6)
  502. z = [z] * 10
  503. coeffs = (2.0, 3.0, 4.0)
  504. y = zeros.newton(f, z, args=coeffs, fprime=f_1, fprime2=f_2, tol=1e-6)
  505. assert_allclose(f(y, *coeffs), 0, atol=1e-6)
  506. def test_zero_der_nz_dp(capsys):
  507. """Test secant method with a non-zero dp, but an infinite newton step"""
  508. # pick a symmetrical functions and choose a point on the side that with dx
  509. # makes a secant that is a flat line with zero slope, EG: f = (x - 100)**2,
  510. # which has a root at x = 100 and is symmetrical around the line x = 100
  511. # we have to pick a really big number so that it is consistently true
  512. # now find a point on each side so that the secant has a zero slope
  513. dx = np.finfo(float).eps ** 0.33
  514. # 100 - p0 = p1 - 100 = p0 * (1 + dx) + dx - 100
  515. # -> 200 = p0 * (2 + dx) + dx
  516. p0 = (200.0 - dx) / (2.0 + dx)
  517. with warnings.catch_warnings():
  518. warnings.filterwarnings("ignore", "RMS of", RuntimeWarning)
  519. x = zeros.newton(lambda y: (y - 100.0)**2, x0=[p0] * 10)
  520. assert_allclose(x, [100] * 10)
  521. # test scalar cases too
  522. p0 = (2.0 - 1e-4) / (2.0 + 1e-4)
  523. with warnings.catch_warnings():
  524. warnings.filterwarnings("ignore", "Tolerance of", RuntimeWarning)
  525. x = zeros.newton(lambda y: (y - 1.0) ** 2, x0=p0, disp=False)
  526. assert_allclose(x, 1)
  527. with pytest.raises(RuntimeError, match='Tolerance of'):
  528. x = zeros.newton(lambda y: (y - 1.0) ** 2, x0=p0, disp=True)
  529. p0 = (-2.0 + 1e-4) / (2.0 + 1e-4)
  530. with warnings.catch_warnings():
  531. warnings.filterwarnings("ignore", "Tolerance of", RuntimeWarning)
  532. x = zeros.newton(lambda y: (y + 1.0) ** 2, x0=p0, disp=False)
  533. assert_allclose(x, -1)
  534. with pytest.raises(RuntimeError, match='Tolerance of'):
  535. x = zeros.newton(lambda y: (y + 1.0) ** 2, x0=p0, disp=True)
  536. def test_array_newton_failures():
  537. """Test that array newton fails as expected"""
  538. # p = 0.68 # [MPa]
  539. # dp = -0.068 * 1e6 # [Pa]
  540. # T = 323 # [K]
  541. diameter = 0.10 # [m]
  542. # L = 100 # [m]
  543. roughness = 0.00015 # [m]
  544. rho = 988.1 # [kg/m**3]
  545. mu = 5.4790e-04 # [Pa*s]
  546. u = 2.488 # [m/s]
  547. reynolds_number = rho * u * diameter / mu # Reynolds number
  548. def colebrook_eqn(darcy_friction, re, dia):
  549. return (1 / np.sqrt(darcy_friction) +
  550. 2 * np.log10(roughness / 3.7 / dia +
  551. 2.51 / re / np.sqrt(darcy_friction)))
  552. # only some failures
  553. with pytest.warns(RuntimeWarning):
  554. result = zeros.newton(
  555. colebrook_eqn, x0=[0.01, 0.2, 0.02223, 0.3], maxiter=2,
  556. args=[reynolds_number, diameter], full_output=True
  557. )
  558. assert not result.converged.all()
  559. # they all fail
  560. with pytest.raises(RuntimeError):
  561. result = zeros.newton(
  562. colebrook_eqn, x0=[0.01] * 2, maxiter=2,
  563. args=[reynolds_number, diameter], full_output=True
  564. )
  565. # this test should **not** raise a RuntimeWarning
  566. def test_gh8904_zeroder_at_root_fails():
  567. """Test that Newton or Halley don't warn if zero derivative at root"""
  568. # a function that has a zero derivative at it's root
  569. def f_zeroder_root(x):
  570. return x**3 - x**2
  571. # should work with secant
  572. r = zeros.newton(f_zeroder_root, x0=0)
  573. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  574. # test again with array
  575. r = zeros.newton(f_zeroder_root, x0=[0]*10)
  576. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  577. # 1st derivative
  578. def fder(x):
  579. return 3 * x**2 - 2 * x
  580. # 2nd derivative
  581. def fder2(x):
  582. return 6*x - 2
  583. # should work with newton and halley
  584. r = zeros.newton(f_zeroder_root, x0=0, fprime=fder)
  585. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  586. r = zeros.newton(f_zeroder_root, x0=0, fprime=fder,
  587. fprime2=fder2)
  588. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  589. # test again with array
  590. r = zeros.newton(f_zeroder_root, x0=[0]*10, fprime=fder)
  591. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  592. r = zeros.newton(f_zeroder_root, x0=[0]*10, fprime=fder,
  593. fprime2=fder2)
  594. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  595. # also test that if a root is found we do not raise RuntimeWarning even if
  596. # the derivative is zero, EG: at x = 0.5, then fval = -0.125 and
  597. # fder = -0.25 so the next guess is 0.5 - (-0.125/-0.5) = 0 which is the
  598. # root, but if the solver continued with that guess, then it will calculate
  599. # a zero derivative, so it should return the root w/o RuntimeWarning
  600. r = zeros.newton(f_zeroder_root, x0=0.5, fprime=fder)
  601. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  602. # test again with array
  603. r = zeros.newton(f_zeroder_root, x0=[0.5]*10, fprime=fder)
  604. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  605. # doesn't apply to halley
  606. def test_gh_8881():
  607. r"""Test that Halley's method realizes that the 2nd order adjustment
  608. is too big and drops off to the 1st order adjustment."""
  609. n = 9
  610. def f(x):
  611. return power(x, 1.0/n) - power(n, 1.0/n)
  612. def fp(x):
  613. return power(x, (1.0-n)/n)/n
  614. def fpp(x):
  615. return power(x, (1.0-2*n)/n) * (1.0/n) * (1.0-n)/n
  616. x0 = 0.1
  617. # The root is at x=9.
  618. # The function has positive slope, x0 < root.
  619. # Newton succeeds in 8 iterations
  620. rt, r = newton(f, x0, fprime=fp, full_output=True)
  621. assert r.converged
  622. # Before the Issue 8881/PR 8882, halley would send x in the wrong direction.
  623. # Check that it now succeeds.
  624. rt, r = newton(f, x0, fprime=fp, fprime2=fpp, full_output=True)
  625. assert r.converged
  def test_gh_9608_preserve_array_shape():
      """
      Check how `newton` handles array-shaped ``x0`` when ``fprime2`` may
      return a scalar (gh-9608).

      - A single-element array ``x0`` with a scalar-returning ``fprime2``
        converges, and the returned root keeps the shape of ``x0``.
      - A multi-element ``x0`` with a scalar-returning ``fprime2`` raises
        IndexError. When ``fprime2`` instead returns an array of matching
        shape, every element converges and the root keeps that shape.
      """
      def f(x):
          return x**2

      def fp(x):
          return 2 * x

      def fpp(x):
          # deliberately a scalar, whatever the shape of x
          return 2

      x0 = np.array([-2], dtype=np.float32)
      rt, r = newton(f, x0, fprime=fp, fprime2=fpp, full_output=True)
      assert r.converged
      # the property this test is named after: the input shape is preserved
      assert np.shape(rt) == x0.shape

      x0_array = np.array([-2, -3], dtype=np.float32)
      # With more than one element, the scalar returned by fpp is expected to
      # fail (presumably the solver indexes the second derivative per element).
      with pytest.raises(IndexError):
          zeros.newton(f, x0_array, fprime=fp, fprime2=fpp, full_output=True)

      def fpp_array(x):
          return np.full(np.shape(x), 2, dtype=np.float32)

      result = zeros.newton(
          f, x0_array, fprime=fp, fprime2=fpp_array, full_output=True
      )
      assert result.converged.all()
      assert result.root.shape == x0_array.shape
  @pytest.mark.parametrize(
      "maximum_iterations,flag_expected",
      [(10, zeros.CONVERR), (100, zeros.CONVERGED)])
  def test_gh9254_flag_if_maxiter_exceeded(maximum_iterations, flag_expected):
      """
      brentq must flag non-convergence when it runs out of iterations, and
      flag convergence when the iteration budget is sufficient (gh-9254).
      """
      def cubic(x):
          return ((1.2*x - 2.3)*x + 3.4)*x - 4.5

      _, info = zeros.brentq(cubic, -30, 30, args=(), xtol=1e-6, rtol=1e-6,
                             maxiter=maximum_iterations,
                             full_output=True, disp=False)
      assert info.flag == flag_expected
      if flag_expected == zeros.CONVERR:
          # the whole iteration budget was used up without converging
          assert info.iterations == maximum_iterations
      elif flag_expected == zeros.CONVERGED:
          # finished strictly inside the iteration budget
          assert info.iterations < maximum_iterations
  def test_gh9551_raise_error_if_disp_true():
      """A zero derivative must make `newton` raise RuntimeError when `disp`
      is left enabled, and only warn when disp=False (gh-9551)."""
      def f(x):
          return x*x + 1

      def f_p(x):
          return 2*x

      # From x=1 the first Newton step lands on x=0, where f_p vanishes.
      with pytest.warns(RuntimeWarning):
          zeros.newton(f, 1.0, f_p, disp=False)

      expected = (r'^Derivative was zero\. Failed to converge after \d+ '
                  r'iterations, value is [+-]?\d*\.\d+\.$')
      with pytest.raises(RuntimeError, match=expected):
          zeros.newton(f, 1.0, f_p)

      # Starting from a complex point, the iteration reaches the root at 1j.
      root = zeros.newton(f, complex(10.0, 10.0), f_p)
      assert_allclose(root, complex(0.0, 1.0))
  686. @pytest.mark.parametrize('solver_name',
  687. ['brentq', 'brenth', 'bisect', 'ridder', 'toms748'])
  688. def test_gh3089_8394(solver_name):
  689. # gh-3089 and gh-8394 reported that bracketing solvers returned incorrect
  690. # results when they encountered NaNs. Check that this is resolved.
  691. def f(x):
  692. return np.nan
  693. solver = getattr(zeros, solver_name)
  694. with pytest.raises(ValueError, match="The function value at x..."):
  695. solver(f, 0, 1)
  696. @pytest.mark.parametrize('method',
  697. ['brentq', 'brenth', 'bisect', 'ridder', 'toms748'])
  698. def test_gh18171(method):
  699. # gh-3089 and gh-8394 reported that bracketing solvers returned incorrect
  700. # results when they encountered NaNs. Check that `root_scalar` returns
  701. # normally but indicates that convergence was unsuccessful. See gh-18171.
  702. def f(x):
  703. f._count += 1
  704. return np.nan
  705. f._count = 0
  706. res = root_scalar(f, bracket=(0, 1), method=method)
  707. assert res.converged is False
  708. assert res.flag.startswith("The function value at x")
  709. assert res.function_calls == f._count
  710. assert str(res.root) in res.flag
  @pytest.mark.parametrize('solver_name',
                           ['brentq', 'brenth', 'bisect', 'ridder', 'toms748'])
  @pytest.mark.parametrize('rs_interface', [True, False])
  def test_function_calls(solver_name, rs_interface):
      # Check that the bracketing solvers report the correct number of
      # function evaluations, both when called directly and when called
      # through `root_scalar`.
      if rs_interface:
          # Forward the solver name. Without `method=`, root_scalar falls back
          # to its default bracketing method, so every parametrization would
          # exercise the same solver. root_scalar always returns a full
          # RootResults, so extra kwargs such as full_output are ignored here.
          def solver(f, a, b, **kwargs):
              return root_scalar(f, bracket=(a, b), method=solver_name)
      else:
          solver = getattr(zeros, solver_name)

      def f(x):
          f.calls += 1
          return x**2 - 1
      f.calls = 0

      res = solver(f, 0, 10, full_output=True)
      # the direct solvers return (root, RootResults) when full_output=True
      info = res if rs_interface else res[1]
      assert info.function_calls == f.calls
  728. def test_gh_14486_converged_false():
  729. """Test that zero slope with secant method results in a converged=False"""
  730. def lhs(x):
  731. return x * np.exp(-x*x) - 0.07
  732. with pytest.warns(RuntimeWarning, match='Tolerance of'):
  733. res = root_scalar(lhs, method='secant', x0=-0.15, x1=1.0)
  734. assert not res.converged
  735. assert res.flag == 'convergence error'
  736. with pytest.warns(RuntimeWarning, match='Tolerance of'):
  737. res = newton(lhs, x0=-0.15, x1=1.0, disp=False, full_output=True)[1]
  738. assert not res.converged
  739. assert res.flag == 'convergence error'
  740. @pytest.mark.parametrize('solver_name',
  741. ['brentq', 'brenth', 'bisect', 'ridder', 'toms748'])
  742. @pytest.mark.parametrize('rs_interface', [True, False])
  743. def test_gh5584(solver_name, rs_interface):
  744. # gh-5584 reported that an underflow can cause sign checks in the algorithm
  745. # to fail. Check that this is resolved.
  746. solver = ((lambda f, a, b, **kwargs: root_scalar(f, bracket=(a, b)))
  747. if rs_interface else getattr(zeros, solver_name))
  748. def f(x):
  749. return 1e-200*x
  750. # Report failure when signs are the same
  751. with pytest.raises(ValueError, match='...must have different signs'):
  752. solver(f, -0.5, -0.4, full_output=True)
  753. # Solve successfully when signs are different
  754. res = solver(f, -0.5, 0.4, full_output=True)
  755. res = res if rs_interface else res[1]
  756. assert res.converged
  757. assert_allclose(res.root, 0, atol=1e-8)
  758. # Solve successfully when one side is negative zero
  759. res = solver(f, -0.5, float('-0.0'), full_output=True)
  760. res = res if rs_interface else res[1]
  761. assert res.converged
  762. assert_allclose(res.root, 0, atol=1e-8)
  763. def test_gh13407():
  764. # gh-13407 reported that the message produced by `scipy.optimize.toms748`
  765. # when `rtol < eps` is incorrect, and also that toms748 is unusual in
  766. # accepting `rtol` as low as eps while other solvers raise at 4*eps. Check
  767. # that the error message has been corrected and that `rtol=eps` can produce
  768. # a lower function value than `rtol=4*eps`.
  769. def f(x):
  770. return x**3 - 2*x - 5
  771. xtol = 1e-300
  772. eps = np.finfo(float).eps
  773. x1 = zeros.toms748(f, 1e-10, 1e10, xtol=xtol, rtol=1*eps)
  774. f1 = f(x1)
  775. x4 = zeros.toms748(f, 1e-10, 1e10, xtol=xtol, rtol=4*eps)
  776. f4 = f(x4)
  777. assert f1 < f4
  778. # using old-style syntax to get exactly the same message
  779. message = fr"rtol too small \({eps/2:g} < {eps:g}\)"
  780. with pytest.raises(ValueError, match=message):
  781. zeros.toms748(f, 1e-10, 1e10, xtol=xtol, rtol=eps/2)
  782. def test_newton_complex_gh10103():
  783. # gh-10103 reported a problem when `newton` is pass a Python complex x0,
  784. # no `fprime` (secant method), and no `x1` (`x1` must be constructed).
  785. # Check that this is resolved.
  786. def f(z):
  787. return z - 1
  788. res = newton(f, 1+1j)
  789. assert_allclose(res, 1, atol=1e-12)
  790. res = root_scalar(f, x0=1+1j, x1=2+1.5j, method='secant')
  791. assert_allclose(res.root, 1, atol=1e-12)
  @pytest.mark.parametrize('method', all_methods)
  def test_maxiter_int_check_gh10236(method):
      # gh-10236: the error raised for a non-integer `maxiter` was hard to
      # interpret. Since gh-10907 it is Python's standard TypeError for a
      # float used where an integer is required.
      expected_msg = "'float' object cannot be interpreted as an integer"
      with pytest.raises(TypeError, match=expected_msg):
          method(f1, 0.0, 1.0, maxiter=72.45)
  @pytest.mark.parametrize("method", [zeros.bisect, zeros.ridder,
                                      zeros.brentq, zeros.brenth])
  def test_bisect_special_parameter(method):
      # Each of these solvers must reject invalid tolerances with ValueError:
      # a negative `xtol`, and an `rtol` below the solver's minimum.
      root = 0.1
      args = (1e-09, 0.004, 10, 0.27456)
      # half of 4*eps; the test expects this to be rejected as too small
      rtolbad = 4 * np.finfo(float).eps / 2

      def f(x, *_extra):
          # Accept (and ignore) the `args` the solvers pass through to f, so
          # a call to f cannot mask the expected ValueError with a TypeError.
          return x - root

      with pytest.raises(ValueError, match="xtol too small"):
          method(f, -1e8, 1e7, args=args, xtol=-1e-6, rtol=TOL)
      with pytest.raises(ValueError, match="rtol too small"):
          method(f, -1e8, 1e7, args=args, xtol=1e-6, rtol=rtolbad)