_root_scalar.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
"""
Unified interfaces to root finding algorithms for real or complex
scalar functions.

Functions
---------
- root_scalar : find a root of a scalar function.
"""
  8. import numpy as np
  9. from . import _zeros_py as optzeros
  10. from ._numdiff import approx_derivative
  11. __all__ = ['root_scalar']
  12. ROOT_SCALAR_METHODS = ['bisect', 'brentq', 'brenth', 'ridder', 'toms748',
  13. 'newton', 'secant', 'halley']
  14. class MemoizeDer:
  15. """Decorator that caches the value and derivative(s) of function each
  16. time it is called.
  17. This is a simplistic memoizer that calls and caches a single value
  18. of ``f(x, *args)``.
  19. It assumes that `args` does not change between invocations.
  20. It supports the use case of a root-finder where `args` is fixed,
  21. `x` changes, and only rarely, if at all, does x assume the same value
  22. more than once."""
  23. def __init__(self, fun):
  24. self.fun = fun
  25. self.vals = None
  26. self.x = None
  27. self.n_calls = 0
  28. def __call__(self, x, *args):
  29. r"""Calculate f or use cached value if available"""
  30. # Derivative may be requested before the function itself, always check
  31. if self.vals is None or x != self.x:
  32. fg = self.fun(x, *args)
  33. self.x = x
  34. self.n_calls += 1
  35. self.vals = fg[:]
  36. return self.vals[0]
  37. def fprime(self, x, *args):
  38. r"""Calculate f' or use a cached value if available"""
  39. if self.vals is None or x != self.x:
  40. self(x, *args)
  41. return self.vals[1]
  42. def fprime2(self, x, *args):
  43. r"""Calculate f'' or use a cached value if available"""
  44. if self.vals is None or x != self.x:
  45. self(x, *args)
  46. return self.vals[2]
  47. def ncalls(self):
  48. return self.n_calls
  49. def root_scalar(f, args=(), method=None, bracket=None,
  50. fprime=None, fprime2=None,
  51. x0=None, x1=None,
  52. xtol=None, rtol=None, maxiter=None,
  53. options=None):
  54. """
  55. Find a root of a scalar function.
  56. Parameters
  57. ----------
  58. f : callable
  59. A function to find a root of.
  60. Suppose the callable has signature ``f0(x, *my_args, **my_kwargs)``, where
  61. ``my_args`` and ``my_kwargs`` are required positional and keyword arguments.
  62. Rather than passing ``f0`` as the callable, wrap it to accept
  63. only ``x``; e.g., pass ``fun=lambda x: f0(x, *my_args, **my_kwargs)`` as the
  64. callable, where ``my_args`` (tuple) and ``my_kwargs`` (dict) have been
  65. gathered before invoking this function.
  66. args : tuple, optional
  67. Extra arguments passed to the objective function and its derivative(s).
  68. method : str, optional
  69. Type of solver. Should be one of
  70. - 'bisect' :ref:`(see here) <optimize.root_scalar-bisect>`
  71. - 'brentq' :ref:`(see here) <optimize.root_scalar-brentq>`
  72. - 'brenth' :ref:`(see here) <optimize.root_scalar-brenth>`
  73. - 'ridder' :ref:`(see here) <optimize.root_scalar-ridder>`
  74. - 'toms748' :ref:`(see here) <optimize.root_scalar-toms748>`
  75. - 'newton' :ref:`(see here) <optimize.root_scalar-newton>`
  76. - 'secant' :ref:`(see here) <optimize.root_scalar-secant>`
  77. - 'halley' :ref:`(see here) <optimize.root_scalar-halley>`
  78. bracket: A sequence of 2 floats, optional
  79. An interval bracketing a root. ``f(x, *args)`` must have different
  80. signs at the two endpoints.
  81. x0 : float, optional
  82. Initial guess.
  83. x1 : float, optional
  84. A second guess.
  85. fprime : bool or callable, optional
  86. If `fprime` is a boolean and is True, `f` is assumed to return the
  87. value of the objective function and of the derivative.
  88. `fprime` can also be a callable returning the derivative of `f`. In
  89. this case, it must accept the same arguments as `f`.
  90. fprime2 : bool or callable, optional
  91. If `fprime2` is a boolean and is True, `f` is assumed to return the
  92. value of the objective function and of the
  93. first and second derivatives.
  94. `fprime2` can also be a callable returning the second derivative of `f`.
  95. In this case, it must accept the same arguments as `f`.
  96. xtol : float, optional
  97. Tolerance (absolute) for termination.
  98. rtol : float, optional
  99. Tolerance (relative) for termination.
  100. maxiter : int, optional
  101. Maximum number of iterations.
  102. options : dict, optional
  103. A dictionary of solver options. E.g., ``k``, see
  104. :obj:`show_options()` for details.
  105. Returns
  106. -------
  107. sol : RootResults
  108. The solution represented as a ``RootResults`` object.
  109. Important attributes are: ``root`` the solution , ``converged`` a
  110. boolean flag indicating if the algorithm exited successfully and
  111. ``flag`` which describes the cause of the termination. See
  112. `RootResults` for a description of other attributes.
  113. See also
  114. --------
  115. show_options : Additional options accepted by the solvers
  116. root : Find a root of a vector function.
  117. Notes
  118. -----
  119. This section describes the available solvers that can be selected by the
  120. 'method' parameter.
  121. The default is to use the best method available for the situation
  122. presented.
  123. If a bracket is provided, it may use one of the bracketing methods.
  124. If a derivative and an initial value are specified, it may
  125. select one of the derivative-based methods.
  126. If no method is judged applicable, it will raise an Exception.
  127. Arguments for each method are as follows (x=required, o=optional).
  128. +-----------------------------------------------+---+------+---------+----+----+--------+---------+------+------+---------+---------+
  129. | method | f | args | bracket | x0 | x1 | fprime | fprime2 | xtol | rtol | maxiter | options |
  130. +===============================================+===+======+=========+====+====+========+=========+======+======+=========+=========+
  131. | :ref:`bisect <optimize.root_scalar-bisect>` | x | o | x | | | | | o | o | o | o |
  132. +-----------------------------------------------+---+------+---------+----+----+--------+---------+------+------+---------+---------+
  133. | :ref:`brentq <optimize.root_scalar-brentq>` | x | o | x | | | | | o | o | o | o |
  134. +-----------------------------------------------+---+------+---------+----+----+--------+---------+------+------+---------+---------+
  135. | :ref:`brenth <optimize.root_scalar-brenth>` | x | o | x | | | | | o | o | o | o |
  136. +-----------------------------------------------+---+------+---------+----+----+--------+---------+------+------+---------+---------+
  137. | :ref:`ridder <optimize.root_scalar-ridder>` | x | o | x | | | | | o | o | o | o |
  138. +-----------------------------------------------+---+------+---------+----+----+--------+---------+------+------+---------+---------+
  139. | :ref:`toms748 <optimize.root_scalar-toms748>` | x | o | x | | | | | o | o | o | o |
  140. +-----------------------------------------------+---+------+---------+----+----+--------+---------+------+------+---------+---------+
  141. | :ref:`secant <optimize.root_scalar-secant>` | x | o | | x | o | | | o | o | o | o |
  142. +-----------------------------------------------+---+------+---------+----+----+--------+---------+------+------+---------+---------+
  143. | :ref:`newton <optimize.root_scalar-newton>` | x | o | | x | | o | | o | o | o | o |
  144. +-----------------------------------------------+---+------+---------+----+----+--------+---------+------+------+---------+---------+
  145. | :ref:`halley <optimize.root_scalar-halley>` | x | o | | x | | x | x | o | o | o | o |
  146. +-----------------------------------------------+---+------+---------+----+----+--------+---------+------+------+---------+---------+
  147. Examples
  148. --------
  149. Find the root of a simple cubic
  150. >>> from scipy import optimize
  151. >>> def f(x):
  152. ... return (x**3 - 1) # only one real root at x = 1
  153. >>> def fprime(x):
  154. ... return 3*x**2
  155. The `brentq` method takes as input a bracket
  156. >>> sol = optimize.root_scalar(f, bracket=[0, 3], method='brentq')
  157. >>> sol.root, sol.iterations, sol.function_calls
  158. (1.0, 10, 11)
  159. The `newton` method takes as input a single point and uses the
  160. derivative(s).
  161. >>> sol = optimize.root_scalar(f, x0=0.2, fprime=fprime, method='newton')
  162. >>> sol.root, sol.iterations, sol.function_calls
  163. (1.0, 11, 22)
  164. The function can provide the value and derivative(s) in a single call.
  165. >>> def f_p_pp(x):
  166. ... return (x**3 - 1), 3*x**2, 6*x
  167. >>> sol = optimize.root_scalar(
  168. ... f_p_pp, x0=0.2, fprime=True, method='newton'
  169. ... )
  170. >>> sol.root, sol.iterations, sol.function_calls
  171. (1.0, 11, 11)
  172. >>> sol = optimize.root_scalar(
  173. ... f_p_pp, x0=0.2, fprime=True, fprime2=True, method='halley'
  174. ... )
  175. >>> sol.root, sol.iterations, sol.function_calls
  176. (1.0, 7, 8)
  177. """ # noqa: E501
  178. if not isinstance(args, tuple):
  179. args = (args,)
  180. if options is None:
  181. options = {}
  182. # fun also returns the derivative(s)
  183. is_memoized = False
  184. if fprime2 is not None and not callable(fprime2):
  185. if bool(fprime2):
  186. f = MemoizeDer(f)
  187. is_memoized = True
  188. fprime2 = f.fprime2
  189. fprime = f.fprime
  190. else:
  191. fprime2 = None
  192. if fprime is not None and not callable(fprime):
  193. if bool(fprime):
  194. f = MemoizeDer(f)
  195. is_memoized = True
  196. fprime = f.fprime
  197. else:
  198. fprime = None
  199. # respect solver-specific default tolerances - only pass in if actually set
  200. kwargs = {}
  201. for k in ['xtol', 'rtol', 'maxiter']:
  202. v = locals().get(k)
  203. if v is not None:
  204. kwargs[k] = v
  205. # Set any solver-specific options
  206. if options:
  207. kwargs.update(options)
  208. # Always request full_output from the underlying method as _root_scalar
  209. # always returns a RootResults object
  210. kwargs.update(full_output=True, disp=False)
  211. # Pick a method if not specified.
  212. # Use the "best" method available for the situation.
  213. if not method:
  214. if bracket is not None:
  215. method = 'brentq'
  216. elif x0 is not None:
  217. if fprime:
  218. if fprime2:
  219. method = 'halley'
  220. else:
  221. method = 'newton'
  222. elif x1 is not None:
  223. method = 'secant'
  224. else:
  225. method = 'newton'
  226. if not method:
  227. raise ValueError('Unable to select a solver as neither bracket '
  228. 'nor starting point provided.')
  229. meth = method.lower()
  230. map2underlying = {'halley': 'newton', 'secant': 'newton'}
  231. try:
  232. methodc = getattr(optzeros, map2underlying.get(meth, meth))
  233. except AttributeError as e:
  234. raise ValueError(f'Unknown solver {meth}') from e
  235. if meth in ['bisect', 'ridder', 'brentq', 'brenth', 'toms748']:
  236. if not isinstance(bracket, list | tuple | np.ndarray):
  237. raise ValueError(f'Bracket needed for {method}')
  238. a, b = bracket[:2]
  239. try:
  240. r, sol = methodc(f, a, b, args=args, **kwargs)
  241. except ValueError as e:
  242. # gh-17622 fixed some bugs in low-level solvers by raising an error
  243. # (rather than returning incorrect results) when the callable
  244. # returns a NaN. It did so by wrapping the callable rather than
  245. # modifying compiled code, so the iteration count is not available.
  246. if hasattr(e, "_x"):
  247. sol = optzeros.RootResults(root=e._x,
  248. iterations=np.nan,
  249. function_calls=e._function_calls,
  250. flag=str(e), method=method)
  251. else:
  252. raise
  253. elif meth in ['secant']:
  254. if x0 is None:
  255. raise ValueError(f'x0 must not be None for {method}')
  256. if 'xtol' in kwargs:
  257. kwargs['tol'] = kwargs.pop('xtol')
  258. r, sol = methodc(f, x0, args=args, fprime=None, fprime2=None,
  259. x1=x1, **kwargs)
  260. elif meth in ['newton']:
  261. if x0 is None:
  262. raise ValueError(f'x0 must not be None for {method}')
  263. if not fprime:
  264. # approximate fprime with finite differences
  265. def fprime(x, *args):
  266. # `root_scalar` doesn't actually seem to support vectorized
  267. # use of `newton`. In that case, `approx_derivative` will
  268. # always get scalar input. Nonetheless, it always returns an
  269. # array, so we extract the element to produce scalar output.
  270. # Similarly, `approx_derivative` always passes array input, so
  271. # we extract the element to ensure the user's function gets
  272. # scalar input.
  273. def f_wrapped(x, *args):
  274. return f(x[0], *args)
  275. return approx_derivative(f_wrapped, x, method='2-point', args=args)[0]
  276. if 'xtol' in kwargs:
  277. kwargs['tol'] = kwargs.pop('xtol')
  278. r, sol = methodc(f, x0, args=args, fprime=fprime, fprime2=None,
  279. **kwargs)
  280. elif meth in ['halley']:
  281. if x0 is None:
  282. raise ValueError(f'x0 must not be None for {method}')
  283. if not fprime:
  284. raise ValueError(f'fprime must be specified for {method}')
  285. if not fprime2:
  286. raise ValueError(f'fprime2 must be specified for {method}')
  287. if 'xtol' in kwargs:
  288. kwargs['tol'] = kwargs.pop('xtol')
  289. r, sol = methodc(f, x0, args=args, fprime=fprime, fprime2=fprime2, **kwargs)
  290. else:
  291. raise ValueError(f'Unknown solver {method}')
  292. if is_memoized:
  293. # Replace the function_calls count with the memoized count.
  294. # Avoids double and triple-counting.
  295. n_calls = f.n_calls
  296. sol.function_calls = n_calls
  297. return sol
  298. def _root_scalar_brentq_doc():
  299. r"""
  300. Options
  301. -------
  302. args : tuple, optional
  303. Extra arguments passed to the objective function.
  304. bracket: A sequence of 2 floats, optional
  305. An interval bracketing a root. ``f(x, *args)`` must have different
  306. signs at the two endpoints.
  307. xtol : float, optional
  308. Tolerance (absolute) for termination.
  309. rtol : float, optional
  310. Tolerance (relative) for termination.
  311. maxiter : int, optional
  312. Maximum number of iterations.
  313. options: dict, optional
  314. Specifies any method-specific options not covered above
  315. """
  316. pass
  317. def _root_scalar_brenth_doc():
  318. r"""
  319. Options
  320. -------
  321. args : tuple, optional
  322. Extra arguments passed to the objective function.
  323. bracket: A sequence of 2 floats, optional
  324. An interval bracketing a root. ``f(x, *args)`` must have different
  325. signs at the two endpoints.
  326. xtol : float, optional
  327. Tolerance (absolute) for termination.
  328. rtol : float, optional
  329. Tolerance (relative) for termination.
  330. maxiter : int, optional
  331. Maximum number of iterations.
  332. options: dict, optional
  333. Specifies any method-specific options not covered above.
  334. """
  335. pass
  336. def _root_scalar_toms748_doc():
  337. r"""
  338. Options
  339. -------
  340. args : tuple, optional
  341. Extra arguments passed to the objective function.
  342. bracket: A sequence of 2 floats, optional
  343. An interval bracketing a root. ``f(x, *args)`` must have different
  344. signs at the two endpoints.
  345. xtol : float, optional
  346. Tolerance (absolute) for termination.
  347. rtol : float, optional
  348. Tolerance (relative) for termination.
  349. maxiter : int, optional
  350. Maximum number of iterations.
  351. options: dict, optional
  352. Specifies any method-specific options not covered above.
  353. """
  354. pass
  355. def _root_scalar_secant_doc():
  356. r"""
  357. Options
  358. -------
  359. args : tuple, optional
  360. Extra arguments passed to the objective function.
  361. xtol : float, optional
  362. Tolerance (absolute) for termination.
  363. rtol : float, optional
  364. Tolerance (relative) for termination.
  365. maxiter : int, optional
  366. Maximum number of iterations.
  367. x0 : float, required
  368. Initial guess.
  369. x1 : float, optional
  370. A second guess. Must be different from `x0`. If not specified,
  371. a value near `x0` will be chosen.
  372. options: dict, optional
  373. Specifies any method-specific options not covered above.
  374. """
  375. pass
  376. def _root_scalar_newton_doc():
  377. r"""
  378. Options
  379. -------
  380. args : tuple, optional
  381. Extra arguments passed to the objective function and its derivative.
  382. xtol : float, optional
  383. Tolerance (absolute) for termination.
  384. rtol : float, optional
  385. Tolerance (relative) for termination.
  386. maxiter : int, optional
  387. Maximum number of iterations.
  388. x0 : float, required
  389. Initial guess.
  390. fprime : bool or callable, optional
  391. If `fprime` is a boolean and is True, `f` is assumed to return the
  392. value of derivative along with the objective function.
  393. `fprime` can also be a callable returning the derivative of `f`. In
  394. this case, it must accept the same arguments as `f`.
  395. options: dict, optional
  396. Specifies any method-specific options not covered above.
  397. """
  398. pass
  399. def _root_scalar_halley_doc():
  400. r"""
  401. Options
  402. -------
  403. args : tuple, optional
  404. Extra arguments passed to the objective function and its derivatives.
  405. xtol : float, optional
  406. Tolerance (absolute) for termination.
  407. rtol : float, optional
  408. Tolerance (relative) for termination.
  409. maxiter : int, optional
  410. Maximum number of iterations.
  411. x0 : float, required
  412. Initial guess.
  413. fprime : bool or callable, required
  414. If `fprime` is a boolean and is True, `f` is assumed to return the
  415. value of derivative along with the objective function.
  416. `fprime` can also be a callable returning the derivative of `f`. In
  417. this case, it must accept the same arguments as `f`.
  418. fprime2 : bool or callable, required
  419. If `fprime2` is a boolean and is True, `f` is assumed to return the
  420. value of 1st and 2nd derivatives along with the objective function.
  421. `fprime2` can also be a callable returning the 2nd derivative of `f`.
  422. In this case, it must accept the same arguments as `f`.
  423. options: dict, optional
  424. Specifies any method-specific options not covered above.
  425. """
  426. pass
  427. def _root_scalar_ridder_doc():
  428. r"""
  429. Options
  430. -------
  431. args : tuple, optional
  432. Extra arguments passed to the objective function.
  433. bracket: A sequence of 2 floats, optional
  434. An interval bracketing a root. ``f(x, *args)`` must have different
  435. signs at the two endpoints.
  436. xtol : float, optional
  437. Tolerance (absolute) for termination.
  438. rtol : float, optional
  439. Tolerance (relative) for termination.
  440. maxiter : int, optional
  441. Maximum number of iterations.
  442. options: dict, optional
  443. Specifies any method-specific options not covered above.
  444. """
  445. pass
  446. def _root_scalar_bisect_doc():
  447. r"""
  448. Options
  449. -------
  450. args : tuple, optional
  451. Extra arguments passed to the objective function.
  452. bracket: A sequence of 2 floats, optional
  453. An interval bracketing a root. ``f(x, *args)`` must have different
  454. signs at the two endpoints.
  455. xtol : float, optional
  456. Tolerance (absolute) for termination.
  457. rtol : float, optional
  458. Tolerance (relative) for termination.
  459. maxiter : int, optional
  460. Maximum number of iterations.
  461. options: dict, optional
  462. Specifies any method-specific options not covered above.
  463. """
  464. pass