_elementwise.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816
  1. from scipy.optimize._bracket import _bracket_root, _bracket_minimum
  2. from scipy.optimize._chandrupatla import _chandrupatla, _chandrupatla_minimize
  3. from scipy._lib._util import _RichResult
  4. from scipy._lib._array_api import xp_capabilities
  5. @xp_capabilities(
  6. skip_backends=[('dask.array', 'boolean indexing assignment'),
  7. ('array_api_strict', 'Currently uses fancy indexing assignment.'),
  8. ('jax.numpy', 'JAX arrays do not support item assignment.')])
  9. def find_root(f, init, /, *, args=(), tolerances=None, maxiter=None, callback=None):
  10. """Find the root of a monotonic, real-valued function of a real variable.
  11. For each element of the output of `f`, `find_root` seeks the scalar
  12. root that makes the element 0. This function currently uses Chandrupatla's
  13. bracketing algorithm [1]_ and therefore requires argument `init` to
  14. provide a bracket around the root: the function values at the two endpoints
  15. must have opposite signs.
  16. Provided a valid bracket, `find_root` is guaranteed to converge to a solution
  17. that satisfies the provided `tolerances` if the function is continuous within
  18. the bracket.
  19. This function works elementwise when `init` and `args` contain (broadcastable)
  20. arrays.
  21. Parameters
  22. ----------
  23. f : callable
  24. The function whose root is desired. The signature must be::
  25. f(x: array, *args) -> array
  26. where each element of ``x`` is a finite real and ``args`` is a tuple,
  27. which may contain an arbitrary number of arrays that are broadcastable
  28. with ``x``.
  29. `f` must be an elementwise function: each element ``f(x)[i]``
  30. must equal ``f(x[i])`` for all indices ``i``. It must not mutate the
  31. array ``x`` or the arrays in ``args``.
  32. `find_root` seeks an array ``x`` such that ``f(x)`` is an array of zeros.
  33. init : 2-tuple of float array_like
  34. The lower and upper endpoints of a bracket surrounding the desired root.
  35. A bracket is valid if arrays ``xl, xr = init`` satisfy ``xl < xr`` and
  36. ``sign(f(xl)) == -sign(f(xr))`` elementwise. Arrays be broadcastable with
  37. one another and `args`.
  38. args : tuple of array_like, optional
  39. Additional positional array arguments to be passed to `f`. Arrays
  40. must be broadcastable with one another and the arrays of `init`.
  41. If the callable for which the root is desired requires arguments that are
  42. not broadcastable with `x`, wrap that callable with `f` such that `f`
  43. accepts only `x` and broadcastable ``*args``.
  44. tolerances : dictionary of floats, optional
  45. Absolute and relative tolerances on the root and function value.
  46. Valid keys of the dictionary are:
  47. - ``xatol`` - absolute tolerance on the root
  48. - ``xrtol`` - relative tolerance on the root
  49. - ``fatol`` - absolute tolerance on the function value
  50. - ``frtol`` - relative tolerance on the function value
  51. See Notes for default values and explicit termination conditions.
  52. maxiter : int, optional
  53. The maximum number of iterations of the algorithm to perform.
  54. The default is the maximum possible number of bisections within
  55. the (normal) floating point numbers of the relevant dtype.
  56. callback : callable, optional
  57. An optional user-supplied function to be called before the first
  58. iteration and after each iteration.
  59. Called as ``callback(res)``, where ``res`` is a ``_RichResult``
  60. similar to that returned by `find_root` (but containing the current
  61. iterate's values of all variables). If `callback` raises a
  62. ``StopIteration``, the algorithm will terminate immediately and
  63. `find_root` will return a result. `callback` must not mutate
  64. `res` or its attributes.
  65. Returns
  66. -------
  67. res : _RichResult
  68. An object similar to an instance of `scipy.optimize.OptimizeResult` with the
  69. following attributes. The descriptions are written as though the values will
  70. be scalars; however, if `f` returns an array, the outputs will be
  71. arrays of the same shape.
  72. success : bool array
  73. ``True`` where the algorithm terminated successfully (status ``0``);
  74. ``False`` otherwise.
  75. status : int array
  76. An integer representing the exit status of the algorithm.
  77. - ``0`` : The algorithm converged to the specified tolerances.
  78. - ``-1`` : The initial bracket was invalid.
  79. - ``-2`` : The maximum number of iterations was reached.
  80. - ``-3`` : A non-finite value was encountered.
  81. - ``-4`` : Iteration was terminated by `callback`.
  82. - ``1`` : The algorithm is proceeding normally (in `callback` only).
  83. x : float array
  84. The root of the function, if the algorithm terminated successfully.
  85. f_x : float array
  86. The value of `f` evaluated at `x`.
  87. nfev : int array
  88. The number of abscissae at which `f` was evaluated to find the root.
  89. This is distinct from the number of times `f` is *called* because the
  90. the function may evaluated at multiple points in a single call.
  91. nit : int array
  92. The number of iterations of the algorithm that were performed.
  93. bracket : tuple of float arrays
  94. The lower and upper endpoints of the final bracket.
  95. f_bracket : tuple of float arrays
  96. The value of `f` evaluated at the lower and upper endpoints of the
  97. bracket.
  98. Notes
  99. -----
  100. Implemented based on Chandrupatla's original paper [1]_.
  101. Let:
  102. - ``a, b = init`` be the left and right endpoints of the initial bracket,
  103. - ``xl`` and ``xr`` be the left and right endpoints of the final bracket,
  104. - ``xmin = xl if abs(f(xl)) <= abs(f(xr)) else xr`` be the final bracket
  105. endpoint with the smaller function value, and
  106. - ``fmin0 = min(f(a), f(b))`` be the minimum of the two values of the
  107. function evaluated at the initial bracket endpoints.
  108. Then the algorithm is considered to have converged when
  109. - ``abs(xr - xl) < xatol + abs(xmin) * xrtol`` or
  110. - ``fun(xmin) <= fatol + abs(fmin0) * frtol``.
  111. This is equivalent to the termination condition described in [1]_ with
  112. ``xrtol = 4e-10``, ``xatol = 1e-5``, and ``fatol = frtol = 0``.
  113. However, the default values of the `tolerances` dictionary are
  114. ``xatol = 4*tiny``, ``xrtol = 4*eps``, ``frtol = 0``, and ``fatol = tiny``,
  115. where ``eps`` and ``tiny`` are the precision and smallest normal number
  116. of the result ``dtype`` of function inputs and outputs.
  117. References
  118. ----------
  119. .. [1] Chandrupatla, Tirupathi R.
  120. "A new hybrid quadratic/bisection algorithm for finding the zero of a
  121. nonlinear function without using derivatives".
  122. Advances in Engineering Software, 28(3), 145-149.
  123. https://doi.org/10.1016/s0965-9978(96)00051-8
  124. See Also
  125. --------
  126. bracket_root
  127. Examples
  128. --------
  129. Suppose we wish to find the root of the following function.
  130. >>> def f(x, c=5):
  131. ... return x**3 - 2*x - c
  132. First, we must find a valid bracket. The function is not monotonic,
  133. but `bracket_root` may be able to provide a bracket.
  134. >>> from scipy.optimize import elementwise
  135. >>> res_bracket = elementwise.bracket_root(f, 0)
  136. >>> res_bracket.success
  137. True
  138. >>> res_bracket.bracket
  139. (2.0, 4.0)
  140. Indeed, the values of the function at the bracket endpoints have
  141. opposite signs.
  142. >>> res_bracket.f_bracket
  143. (-1.0, 51.0)
  144. Once we have a valid bracket, `find_root` can be used to provide
  145. a precise root.
  146. >>> res_root = elementwise.find_root(f, res_bracket.bracket)
  147. >>> res_root.x
  148. 2.0945514815423265
  149. The final bracket is only a few ULPs wide, so the error between
  150. this value and the true root cannot be much smaller within values
  151. that are representable in double precision arithmetic.
  152. >>> import numpy as np
  153. >>> xl, xr = res_root.bracket
  154. >>> (xr - xl) / np.spacing(xl)
  155. 2.0
  156. >>> res_root.f_bracket
  157. (-8.881784197001252e-16, 9.769962616701378e-15)
  158. `bracket_root` and `find_root` accept arrays for most arguments.
  159. For instance, to find the root for a few values of the parameter ``c``
  160. at once:
  161. >>> c = np.asarray([3, 4, 5])
  162. >>> res_bracket = elementwise.bracket_root(f, 0, args=(c,))
  163. >>> res_bracket.bracket
  164. (array([1., 1., 2.]), array([2., 2., 4.]))
  165. >>> res_root = elementwise.find_root(f, res_bracket.bracket, args=(c,))
  166. >>> res_root.x
  167. array([1.8932892 , 2. , 2.09455148])
  168. """
  169. def reformat_result(res_in):
  170. res_out = _RichResult()
  171. res_out.status = res_in.status
  172. res_out.success = res_in.success
  173. res_out.x = res_in.x
  174. res_out.f_x = res_in.fun
  175. res_out.nfev = res_in.nfev
  176. res_out.nit = res_in.nit
  177. res_out.bracket = (res_in.xl, res_in.xr)
  178. res_out.f_bracket = (res_in.fl, res_in.fr)
  179. res_out._order_keys = ['success', 'status', 'x', 'f_x',
  180. 'nfev', 'nit', 'bracket', 'f_bracket']
  181. return res_out
  182. xl, xr = init
  183. default_tolerances = dict(xatol=None, xrtol=None, fatol=None, frtol=0)
  184. tolerances = {} if tolerances is None else tolerances
  185. default_tolerances.update(tolerances)
  186. tolerances = default_tolerances
  187. if callable(callback):
  188. def _callback(res):
  189. return callback(reformat_result(res))
  190. else:
  191. _callback = callback
  192. res = _chandrupatla(f, xl, xr, args=args, **tolerances,
  193. maxiter=maxiter, callback=_callback)
  194. return reformat_result(res)
  195. @xp_capabilities(
  196. skip_backends=[('dask.array', 'boolean indexing assignment'),
  197. ('array_api_strict', 'Currently uses fancy indexing assignment.'),
  198. ('jax.numpy', 'JAX arrays do not support item assignment.')])
  199. def find_minimum(f, init, /, *, args=(), tolerances=None, maxiter=100, callback=None):
  200. """Find the minimum of an unimodal, real-valued function of a real variable.
  201. For each element of the output of `f`, `find_minimum` seeks the scalar minimizer
  202. that minimizes the element. This function currently uses Chandrupatla's
  203. bracketing minimization algorithm [1]_ and therefore requires argument `init`
  204. to provide a three-point minimization bracket: ``x1 < x2 < x3`` such that
  205. ``func(x1) >= func(x2) <= func(x3)``, where one of the inequalities is strict.
  206. Provided a valid bracket, `find_minimum` is guaranteed to converge to a local
  207. minimum that satisfies the provided `tolerances` if the function is continuous
  208. within the bracket.
  209. This function works elementwise when `init` and `args` contain (broadcastable)
  210. arrays.
  211. Parameters
  212. ----------
  213. f : callable
  214. The function whose minimizer is desired. The signature must be::
  215. f(x: array, *args) -> array
  216. where each element of ``x`` is a finite real and ``args`` is a tuple,
  217. which may contain an arbitrary number of arrays that are broadcastable
  218. with ``x``.
  219. `f` must be an elementwise function: each element ``f(x)[i]``
  220. must equal ``f(x[i])`` for all indices ``i``. It must not mutate the
  221. array ``x`` or the arrays in ``args``.
  222. `find_minimum` seeks an array ``x`` such that ``f(x)`` is an array of
  223. local minima.
  224. init : 3-tuple of float array_like
  225. The abscissae of a standard scalar minimization bracket. A bracket is
  226. valid if arrays ``x1, x2, x3 = init`` satisfy ``x1 < x2 < x3`` and
  227. ``func(x1) >= func(x2) <= func(x3)``, where one of the inequalities
  228. is strict. Arrays must be broadcastable with one another and the arrays
  229. of `args`.
  230. args : tuple of array_like, optional
  231. Additional positional array arguments to be passed to `f`. Arrays
  232. must be broadcastable with one another and the arrays of `init`.
  233. If the callable for which the root is desired requires arguments that are
  234. not broadcastable with `x`, wrap that callable with `f` such that `f`
  235. accepts only `x` and broadcastable ``*args``.
  236. tolerances : dictionary of floats, optional
  237. Absolute and relative tolerances on the root and function value.
  238. Valid keys of the dictionary are:
  239. - ``xatol`` - absolute tolerance on the root
  240. - ``xrtol`` - relative tolerance on the root
  241. - ``fatol`` - absolute tolerance on the function value
  242. - ``frtol`` - relative tolerance on the function value
  243. See Notes for default values and explicit termination conditions.
  244. maxiter : int, default: 100
  245. The maximum number of iterations of the algorithm to perform.
  246. callback : callable, optional
  247. An optional user-supplied function to be called before the first
  248. iteration and after each iteration.
  249. Called as ``callback(res)``, where ``res`` is a ``_RichResult``
  250. similar to that returned by `find_minimum` (but containing the current
  251. iterate's values of all variables). If `callback` raises a
  252. ``StopIteration``, the algorithm will terminate immediately and
  253. `find_root` will return a result. `callback` must not mutate
  254. `res` or its attributes.
  255. Returns
  256. -------
  257. res : _RichResult
  258. An object similar to an instance of `scipy.optimize.OptimizeResult` with the
  259. following attributes. The descriptions are written as though the values will
  260. be scalars; however, if `f` returns an array, the outputs will be
  261. arrays of the same shape.
  262. success : bool array
  263. ``True`` where the algorithm terminated successfully (status ``0``);
  264. ``False`` otherwise.
  265. status : int array
  266. An integer representing the exit status of the algorithm.
  267. - ``0`` : The algorithm converged to the specified tolerances.
  268. - ``-1`` : The algorithm encountered an invalid bracket.
  269. - ``-2`` : The maximum number of iterations was reached.
  270. - ``-3`` : A non-finite value was encountered.
  271. - ``-4`` : Iteration was terminated by `callback`.
  272. - ``1`` : The algorithm is proceeding normally (in `callback` only).
  273. x : float array
  274. The minimizer of the function, if the algorithm terminated successfully.
  275. f_x : float array
  276. The value of `f` evaluated at `x`.
  277. nfev : int array
  278. The number of abscissae at which `f` was evaluated to find the root.
  279. This is distinct from the number of times `f` is *called* because the
  280. the function may evaluated at multiple points in a single call.
  281. nit : int array
  282. The number of iterations of the algorithm that were performed.
  283. bracket : tuple of float arrays
  284. The final three-point bracket.
  285. f_bracket : tuple of float arrays
  286. The value of `f` evaluated at the bracket points.
  287. Notes
  288. -----
  289. Implemented based on Chandrupatla's original paper [1]_.
  290. If ``xl < xm < xr`` are the points of the bracket and ``fl >= fm <= fr``
  291. (where one of the inequalities is strict) are the values of `f` evaluated
  292. at those points, then the algorithm is considered to have converged when:
  293. - ``abs(xr - xm)/2 <= abs(xm)*xrtol + xatol`` or
  294. - ``(fl - 2*fm + fr)/2 <= abs(fm)*frtol + fatol``.
  295. The default value of `xrtol` is the square root of the precision of the
  296. appropriate dtype, and ``xatol = fatol = frtol`` is the smallest normal
  297. number of the appropriate dtype.
  298. References
  299. ----------
  300. .. [1] Chandrupatla, Tirupathi R. (1998).
  301. "An efficient quadratic fit-sectioning algorithm for minimization
  302. without derivatives".
  303. Computer Methods in Applied Mechanics and Engineering, 152 (1-2),
  304. 211-217. https://doi.org/10.1016/S0045-7825(97)00190-4
  305. See Also
  306. --------
  307. bracket_minimum
  308. Examples
  309. --------
  310. Suppose we wish to minimize the following function.
  311. >>> def f(x, c=1):
  312. ... return (x - c)**2 + 2
  313. First, we must find a valid bracket. The function is unimodal,
  314. so `bracket_minium` will easily find a bracket.
  315. >>> from scipy.optimize import elementwise
  316. >>> res_bracket = elementwise.bracket_minimum(f, 0)
  317. >>> res_bracket.success
  318. True
  319. >>> res_bracket.bracket
  320. (0.0, 0.5, 1.5)
  321. Indeed, the bracket points are ordered and the function value
  322. at the middle bracket point is less than at the surrounding
  323. points.
  324. >>> xl, xm, xr = res_bracket.bracket
  325. >>> fl, fm, fr = res_bracket.f_bracket
  326. >>> (xl < xm < xr) and (fl > fm <= fr)
  327. True
  328. Once we have a valid bracket, `find_minimum` can be used to provide
  329. an estimate of the minimizer.
  330. >>> res_minimum = elementwise.find_minimum(f, res_bracket.bracket)
  331. >>> res_minimum.x
  332. 1.0000000149011612
  333. The function value changes by only a few ULPs within the bracket, so
  334. the minimizer cannot be determined much more precisely by evaluating
  335. the function alone (i.e. we would need its derivative to do better).
  336. >>> import numpy as np
  337. >>> fl, fm, fr = res_minimum.f_bracket
  338. >>> (fl - fm) / np.spacing(fm), (fr - fm) / np.spacing(fm)
  339. (0.0, 2.0)
  340. Therefore, a precise minimum of the function is given by:
  341. >>> res_minimum.f_x
  342. 2.0
  343. `bracket_minimum` and `find_minimum` accept arrays for most arguments.
  344. For instance, to find the minimizers and minima for a few values of the
  345. parameter ``c`` at once:
  346. >>> c = np.asarray([1, 1.5, 2])
  347. >>> res_bracket = elementwise.bracket_minimum(f, 0, args=(c,))
  348. >>> res_bracket.bracket
  349. (array([0. , 0.5, 0.5]), array([0.5, 1.5, 1.5]), array([1.5, 2.5, 2.5]))
  350. >>> res_minimum = elementwise.find_minimum(f, res_bracket.bracket, args=(c,))
  351. >>> res_minimum.x
  352. array([1.00000001, 1.5 , 2. ])
  353. >>> res_minimum.f_x
  354. array([2., 2., 2.])
  355. """
  356. def reformat_result(res_in):
  357. res_out = _RichResult()
  358. res_out.status = res_in.status
  359. res_out.success = res_in.success
  360. res_out.x = res_in.x
  361. res_out.f_x = res_in.fun
  362. res_out.nfev = res_in.nfev
  363. res_out.nit = res_in.nit
  364. res_out.bracket = (res_in.xl, res_in.xm, res_in.xr)
  365. res_out.f_bracket = (res_in.fl, res_in.fm, res_in.fr)
  366. res_out._order_keys = ['success', 'status', 'x', 'f_x',
  367. 'nfev', 'nit', 'bracket', 'f_bracket']
  368. return res_out
  369. xl, xm, xr = init
  370. default_tolerances = dict(xatol=None, xrtol=None, fatol=None, frtol=None)
  371. tolerances = {} if tolerances is None else tolerances
  372. default_tolerances.update(tolerances)
  373. tolerances = default_tolerances
  374. if callable(callback):
  375. def _callback(res):
  376. return callback(reformat_result(res))
  377. else:
  378. _callback = callback
  379. res = _chandrupatla_minimize(f, xl, xm, xr, args=args, **tolerances,
  380. maxiter=maxiter, callback=_callback)
  381. return reformat_result(res)
  382. @xp_capabilities(
  383. skip_backends=[('dask.array', 'boolean indexing assignment'),
  384. ('array_api_strict', 'Currently uses fancy indexing assignment.'),
  385. ('jax.numpy', 'JAX arrays do not support item assignment.')])
  386. def bracket_root(f, xl0, xr0=None, *, xmin=None, xmax=None, factor=None, args=(),
  387. maxiter=1000):
  388. """Bracket the root of a monotonic, real-valued function of a real variable.
  389. For each element of the output of `f`, `bracket_root` seeks the scalar
  390. bracket endpoints ``xl`` and ``xr`` such that ``sign(f(xl)) == -sign(f(xr))``
  391. elementwise.
  392. The function is guaranteed to find a valid bracket if the function is monotonic,
  393. but it may find a bracket under other conditions.
  394. This function works elementwise when `xl0`, `xr0`, `xmin`, `xmax`, `factor`, and
  395. the elements of `args` are (mutually broadcastable) arrays.
  396. Parameters
  397. ----------
  398. f : callable
  399. The function for which the root is to be bracketed. The signature must be::
  400. f(x: array, *args) -> array
  401. where each element of ``x`` is a finite real and ``args`` is a tuple,
  402. which may contain an arbitrary number of arrays that are broadcastable
  403. with ``x``.
  404. `f` must be an elementwise function: each element ``f(x)[i]``
  405. must equal ``f(x[i])`` for all indices ``i``. It must not mutate the
  406. array ``x`` or the arrays in ``args``.
  407. xl0, xr0: float array_like
  408. Starting guess of bracket, which need not contain a root. If `xr0` is
  409. not provided, ``xr0 = xl0 + 1``. Must be broadcastable with all other
  410. array inputs.
  411. xmin, xmax : float array_like, optional
  412. Minimum and maximum allowable endpoints of the bracket, inclusive. Must
  413. be broadcastable with all other array inputs.
  414. factor : float array_like, default: 2
  415. The factor used to grow the bracket. See Notes.
  416. args : tuple of array_like, optional
  417. Additional positional array arguments to be passed to `f`.
  418. If the callable for which the root is desired requires arguments that are
  419. not broadcastable with `x`, wrap that callable with `f` such that `f`
  420. accepts only `x` and broadcastable ``*args``.
  421. maxiter : int, default: 1000
  422. The maximum number of iterations of the algorithm to perform.
  423. Returns
  424. -------
  425. res : _RichResult
  426. An object similar to an instance of `scipy.optimize.OptimizeResult` with the
  427. following attributes. The descriptions are written as though the values will
  428. be scalars; however, if `f` returns an array, the outputs will be
  429. arrays of the same shape.
  430. success : bool array
  431. ``True`` where the algorithm terminated successfully (status ``0``);
  432. ``False`` otherwise.
  433. status : int array
  434. An integer representing the exit status of the algorithm.
  435. - ``0`` : The algorithm produced a valid bracket.
  436. - ``-1`` : The bracket expanded to the allowable limits without success.
  437. - ``-2`` : The maximum number of iterations was reached.
  438. - ``-3`` : A non-finite value was encountered.
  439. - ``-4`` : Iteration was terminated by `callback`.
  440. - ``-5``: The initial bracket does not satisfy`xmin <= xl0 < xr0 < xmax`.
  441. bracket : 2-tuple of float arrays
  442. The lower and upper endpoints of the bracket, if the algorithm
  443. terminated successfully.
  444. f_bracket : 2-tuple of float arrays
  445. The values of `f` evaluated at the endpoints of ``res.bracket``,
  446. respectively.
  447. nfev : int array
  448. The number of abscissae at which `f` was evaluated to find the root.
  449. This is distinct from the number of times `f` is *called* because the
  450. the function may evaluated at multiple points in a single call.
  451. nit : int array
  452. The number of iterations of the algorithm that were performed.
  453. Notes
  454. -----
  455. This function generalizes an algorithm found in pieces throughout the
  456. `scipy.stats` codebase. The strategy is to iteratively grow the bracket `(l, r)`
  457. until ``f(l) < 0 < f(r)`` or ``f(r) < 0 < f(l)``. The bracket grows to the left
  458. as follows.
  459. - If `xmin` is not provided, the distance between `xl0` and `l` is iteratively
  460. increased by `factor`.
  461. - If `xmin` is provided, the distance between `xmin` and `l` is iteratively
  462. decreased by `factor`. Note that this also *increases* the bracket size.
  463. Growth of the bracket to the right is analogous.
  464. Growth of the bracket in one direction stops when the endpoint is no longer
  465. finite, the function value at the endpoint is no longer finite, or the
  466. endpoint reaches its limiting value (`xmin` or `xmax`). Iteration terminates
  467. when the bracket stops growing in both directions, the bracket surrounds
  468. the root, or a root is found (by chance).
  469. If two brackets are found - that is, a bracket is found on both sides in
  470. the same iteration, the smaller of the two is returned.
  471. If roots of the function are found, both `xl` and `xr` are set to the
  472. leftmost root.
  473. See Also
  474. --------
  475. find_root
  476. Examples
  477. --------
  478. Suppose we wish to find the root of the following function.
  479. >>> def f(x, c=5):
  480. ... return x**3 - 2*x - c
  481. First, we must find a valid bracket. The function is not monotonic,
  482. but `bracket_root` may be able to provide a bracket.
  483. >>> from scipy.optimize import elementwise
  484. >>> res_bracket = elementwise.bracket_root(f, 0)
  485. >>> res_bracket.success
  486. True
  487. >>> res_bracket.bracket
  488. (2.0, 4.0)
  489. Indeed, the values of the function at the bracket endpoints have
  490. opposite signs.
  491. >>> res_bracket.f_bracket
  492. (-1.0, 51.0)
  493. Once we have a valid bracket, `find_root` can be used to provide
  494. a precise root.
  495. >>> res_root = elementwise.find_root(f, res_bracket.bracket)
  496. >>> res_root.x
  497. 2.0945514815423265
  498. `bracket_root` and `find_root` accept arrays for most arguments.
  499. For instance, to find the root for a few values of the parameter ``c``
  500. at once:
  501. >>> import numpy as np
  502. >>> c = np.asarray([3, 4, 5])
  503. >>> res_bracket = elementwise.bracket_root(f, 0, args=(c,))
  504. >>> res_bracket.bracket
  505. (array([1., 1., 2.]), array([2., 2., 4.]))
  506. >>> res_root = elementwise.find_root(f, res_bracket.bracket, args=(c,))
  507. >>> res_root.x
  508. array([1.8932892 , 2. , 2.09455148])
  509. """ # noqa: E501
  510. res = _bracket_root(f, xl0, xr0=xr0, xmin=xmin, xmax=xmax, factor=factor,
  511. args=args, maxiter=maxiter)
  512. res.bracket = res.xl, res.xr
  513. res.f_bracket = res.fl, res.fr
  514. del res.xl
  515. del res.xr
  516. del res.fl
  517. del res.fr
  518. return res
  519. @xp_capabilities(
  520. skip_backends=[('dask.array', 'boolean indexing assignment'),
  521. ('array_api_strict', 'Currently uses fancy indexing assignment.'),
  522. ('jax.numpy', 'JAX arrays do not support item assignment.'),
  523. ('torch', 'data-apis/array-api-compat#271')])
  524. def bracket_minimum(f, xm0, *, xl0=None, xr0=None, xmin=None, xmax=None,
  525. factor=None, args=(), maxiter=1000):
  526. """Bracket the minimum of a unimodal, real-valued function of a real variable.
  527. For each element of the output of `f`, `bracket_minimum` seeks the scalar
  528. bracket points ``xl < xm < xr`` such that ``fl >= fm <= fr`` where one of the
  529. inequalities is strict.
  530. The function is guaranteed to find a valid bracket if the function is
  531. strongly unimodal, but it may find a bracket under other conditions.
  532. This function works elementwise when `xm0`, `xl0`, `xr0`, `xmin`, `xmax`, `factor`,
  533. and the elements of `args` are (mutually broadcastable) arrays.
  534. Parameters
  535. ----------
  536. f : callable
  537. The function for which the root is to be bracketed. The signature must be::
  538. f(x: array, *args) -> array
  539. where each element of ``x`` is a finite real and ``args`` is a tuple,
  540. which may contain an arbitrary number of arrays that are broadcastable
  541. with ``x``.
  542. `f` must be an elementwise function: each element ``f(x)[i]``
  543. must equal ``f(x[i])`` for all indices ``i``. It must not mutate the
  544. array ``x`` or the arrays in ``args``.
  545. xm0: float array_like
  546. Starting guess for middle point of bracket.
  547. xl0, xr0: float array_like, optional
  548. Starting guesses for left and right endpoints of the bracket. Must
  549. be broadcastable with all other array inputs.
  550. xmin, xmax : float array_like, optional
  551. Minimum and maximum allowable endpoints of the bracket, inclusive. Must
  552. be broadcastable with all other array inputs.
  553. factor : float array_like, default: 2
  554. The factor used to grow the bracket. See Notes.
  555. args : tuple of array_like, optional
  556. Additional positional array arguments to be passed to `f`.
  557. If the callable for which the root is desired requires arguments that are
  558. not broadcastable with `x`, wrap that callable with `f` such that `f`
  559. accepts only `x` and broadcastable ``*args``.
  560. maxiter : int, default: 1000
  561. The maximum number of iterations of the algorithm to perform.
  562. Returns
  563. -------
  564. res : _RichResult
  565. An object similar to an instance of `scipy.optimize.OptimizeResult` with the
  566. following attributes. The descriptions are written as though the values will
  567. be scalars; however, if `f` returns an array, the outputs will be
  568. arrays of the same shape.
  569. success : bool array
  570. ``True`` where the algorithm terminated successfully (status ``0``);
  571. ``False`` otherwise.
  572. status : int array
  573. An integer representing the exit status of the algorithm.
  574. - ``0`` : The algorithm produced a valid bracket.
  575. - ``-1`` : The bracket expanded to the allowable limits. Assuming
  576. unimodality, this implies the endpoint at the limit is a minimizer.
  577. - ``-2`` : The maximum number of iterations was reached.
  578. - ``-3`` : A non-finite value was encountered.
  579. - ``-4`` : ``None`` shall pass.
  580. - ``-5`` : The initial bracket does not satisfy
  581. `xmin <= xl0 < xm0 < xr0 <= xmax`.
  582. bracket : 3-tuple of float arrays
  583. The left, middle, and right points of the bracket, if the algorithm
  584. terminated successfully.
  585. f_bracket : 3-tuple of float arrays
  586. The function value at the left, middle, and right points of the bracket.
  587. nfev : int array
  588. The number of abscissae at which `f` was evaluated to find the bracket.
  589. This is distinct from the number of times `f` is *called* because the
  590. function may be evaluated at multiple points in a single call.
  591. nit : int array
  592. The number of iterations of the algorithm that were performed.
  593. Notes
  594. -----
  595. Similar to `scipy.optimize.bracket`, this function seeks to find real
  596. points ``xl < xm < xr`` such that ``f(xl) >= f(xm)`` and ``f(xr) >= f(xm)``,
  597. where at least one of the inequalities is strict. Unlike `scipy.optimize.bracket`,
  598. this function can operate in a vectorized manner on array input, so long as
  599. the input arrays are broadcastable with each other. Also unlike
  600. `scipy.optimize.bracket`, users may specify minimum and maximum endpoints
  601. for the desired bracket.
  602. Given an initial trio of points ``xl = xl0``, ``xm = xm0``, ``xr = xr0``,
  603. the algorithm checks if these points already give a valid bracket. If not,
  604. a new endpoint, ``w``, is chosen in the "downhill" direction, ``xm`` becomes the new
  605. opposite endpoint, and either `xl` or `xr` becomes the new middle point,
  606. depending on which direction is downhill. The algorithm repeats from here.
  607. The new endpoint `w` is chosen differently depending on whether or not a
  608. boundary `xmin` or `xmax` has been set in the downhill direction. Without
  609. loss of generality, suppose the downhill direction is to the right, so that
  610. ``f(xl) > f(xm) > f(xr)``. If there is no boundary to the right, then `w`
  611. is chosen to be ``xr + factor * (xr - xm)`` where `factor` is controlled by
  612. the user (defaults to 2.0) so that step sizes increase in geometric proportion.
  613. If there is a boundary, `xmax` in this case, then `w` is chosen to be
  614. ``xmax - (xmax - xr)/factor``, with steps slowing to a stop at
  615. `xmax`. This cautious approach ensures that a minimum near but distinct from
  616. the boundary isn't missed while also detecting whether or not `xmax` is
  617. a minimizer when `xmax` is reached after a finite number of steps.
  618. See Also
  619. --------
  620. scipy.optimize.bracket
  621. scipy.optimize.elementwise.find_minimum
  622. Examples
  623. --------
  624. Suppose we wish to minimize the following function.
  625. >>> def f(x, c=1):
  626. ... return (x - c)**2 + 2
  627. First, we must find a valid bracket. The function is unimodal,
  628. so `bracket_minimum` will easily find a bracket.
  629. >>> from scipy.optimize import elementwise
  630. >>> res_bracket = elementwise.bracket_minimum(f, 0)
  631. >>> res_bracket.success
  632. True
  633. >>> res_bracket.bracket
  634. (0.0, 0.5, 1.5)
  635. Indeed, the bracket points are ordered and the function value
  636. at the middle bracket point is no greater than at the surrounding
  637. points and strictly less than at least one of them.
  638. >>> xl, xm, xr = res_bracket.bracket
  639. >>> fl, fm, fr = res_bracket.f_bracket
  640. >>> (xl < xm < xr) and (fl > fm <= fr)
  641. True
  642. Once we have a valid bracket, `find_minimum` can be used to provide
  643. an estimate of the minimizer.
  644. >>> res_minimum = elementwise.find_minimum(f, res_bracket.bracket)
  645. >>> res_minimum.x
  646. 1.0000000149011612
  647. `bracket_minimum` and `find_minimum` accept arrays for most arguments.
  648. For instance, to find the minimizers and minima for a few values of the
  649. parameter ``c`` at once:
  650. >>> import numpy as np
  651. >>> c = np.asarray([1, 1.5, 2])
  652. >>> res_bracket = elementwise.bracket_minimum(f, 0, args=(c,))
  653. >>> res_bracket.bracket
  654. (array([0. , 0.5, 0.5]), array([0.5, 1.5, 1.5]), array([1.5, 2.5, 2.5]))
  655. >>> res_minimum = elementwise.find_minimum(f, res_bracket.bracket, args=(c,))
  656. >>> res_minimum.x
  657. array([1.00000001, 1.5 , 2. ])
  658. >>> res_minimum.f_x
  659. array([2., 2., 2.])
  660. """ # noqa: E501
  661. res = _bracket_minimum(f, xm0, xl0=xl0, xr0=xr0, xmin=xmin, xmax=xmax,
  662. factor=factor, args=args, maxiter=maxiter)
  663. res.bracket = res.xl, res.xm, res.xr
  664. res.f_bracket = res.fl, res.fm, res.fr
  665. del res.xl
  666. del res.xm
  667. del res.xr
  668. del res.fl
  669. del res.fm
  670. del res.fr
  671. return res