_chandrupatla.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. import math
  2. import numpy as np
  3. import scipy._lib._elementwise_iterative_method as eim
  4. from scipy._lib._util import _RichResult
  5. from scipy._lib._array_api import xp_copy
  6. # TODO:
  7. # - (maybe?) don't use fancy indexing assignment
  8. # - figure out how to replace the new `try`/`except`s
  9. def _chandrupatla(func, a, b, *, args=(), xatol=None, xrtol=None,
  10. fatol=None, frtol=0, maxiter=None, callback=None):
  11. """Find the root of an elementwise function using Chandrupatla's algorithm.
  12. For each element of the output of `func`, `chandrupatla` seeks the scalar
  13. root that makes the element 0. This function allows for `a`, `b`, and the
  14. output of `func` to be of any broadcastable shapes.
  15. Parameters
  16. ----------
  17. func : callable
  18. The function whose root is desired. The signature must be::
  19. func(x: ndarray, *args) -> ndarray
  20. where each element of ``x`` is a finite real and ``args`` is a tuple,
  21. which may contain an arbitrary number of components of any type(s).
  22. ``func`` must be an elementwise function: each element ``func(x)[i]``
  23. must equal ``func(x[i])`` for all indices ``i``. `_chandrupatla`
  24. seeks an array ``x`` such that ``func(x)`` is an array of zeros.
  25. a, b : array_like
  26. The lower and upper bounds of the root of the function. Must be
  27. broadcastable with one another.
  28. args : tuple, optional
  29. Additional positional arguments to be passed to `func`.
  30. xatol, xrtol, fatol, frtol : float, optional
  31. Absolute and relative tolerances on the root and function value.
  32. See Notes for details.
  33. maxiter : int, optional
  34. The maximum number of iterations of the algorithm to perform.
  35. The default is the maximum possible number of bisections within
  36. the (normal) floating point numbers of the relevant dtype.
  37. callback : callable, optional
  38. An optional user-supplied function to be called before the first
  39. iteration and after each iteration.
  40. Called as ``callback(res)``, where ``res`` is a ``_RichResult``
  41. similar to that returned by `_chandrupatla` (but containing the current
  42. iterate's values of all variables). If `callback` raises a
  43. ``StopIteration``, the algorithm will terminate immediately and
  44. `_chandrupatla` will return a result.
  45. Returns
  46. -------
  47. res : _RichResult
  48. An instance of `scipy._lib._util._RichResult` with the following
  49. attributes. The descriptions are written as though the values will be
  50. scalars; however, if `func` returns an array, the outputs will be
  51. arrays of the same shape.
  52. x : float
  53. The root of the function, if the algorithm terminated successfully.
  54. nfev : int
  55. The number of times the function was called to find the root.
  56. nit : int
  57. The number of iterations of Chandrupatla's algorithm performed.
  58. status : int
  59. An integer representing the exit status of the algorithm.
  60. ``0`` : The algorithm converged to the specified tolerances.
  61. ``-1`` : The algorithm encountered an invalid bracket.
  62. ``-2`` : The maximum number of iterations was reached.
  63. ``-3`` : A non-finite value was encountered.
  64. ``-4`` : Iteration was terminated by `callback`.
  65. ``1`` : The algorithm is proceeding normally (in `callback` only).
  66. success : bool
  67. ``True`` when the algorithm terminated successfully (status ``0``).
  68. fun : float
  69. The value of `func` evaluated at `x`.
  70. xl, xr : float
  71. The lower and upper ends of the bracket.
  72. fl, fr : float
  73. The function value at the lower and upper ends of the bracket.
  74. Notes
  75. -----
  76. Implemented based on Chandrupatla's original paper [1]_.
  77. If ``xl`` and ``xr`` are the left and right ends of the bracket,
  78. ``xmin = xl if abs(func(xl)) <= abs(func(xr)) else xr``,
  79. and ``fmin0 = min(func(a), func(b))``, then the algorithm is considered to
  80. have converged when ``abs(xr - xl) < xatol + abs(xmin) * xrtol`` or
  81. ``fun(xmin) <= fatol + abs(fmin0) * frtol``. This is equivalent to the
  82. termination condition described in [1]_ with ``xrtol = 4e-10``,
  83. ``xatol = 1e-5``, and ``fatol = frtol = 0``. The default values are
  84. ``xatol = 4*tiny``, ``xrtol = 4*eps``, ``frtol = 0``, and ``fatol = tiny``,
  85. where ``eps`` and ``tiny`` are the precision and smallest normal number
  86. of the result ``dtype`` of function inputs and outputs.
  87. References
  88. ----------
  89. .. [1] Chandrupatla, Tirupathi R.
  90. "A new hybrid quadratic/bisection algorithm for finding the zero of a
  91. nonlinear function without using derivatives".
  92. Advances in Engineering Software, 28(3), 145-149.
  93. https://doi.org/10.1016/s0965-9978(96)00051-8
  94. See Also
  95. --------
  96. brentq, brenth, ridder, bisect, newton
  97. Examples
  98. --------
  99. >>> from scipy import optimize
  100. >>> def f(x, c):
  101. ... return x**3 - 2*x - c
  102. >>> c = 5
  103. >>> res = optimize._chandrupatla._chandrupatla(f, 0, 3, args=(c,))
  104. >>> res.x
  105. 2.0945514818937463
  106. >>> c = [3, 4, 5]
  107. >>> res = optimize._chandrupatla._chandrupatla(f, 0, 3, args=(c,))
  108. >>> res.x
  109. array([1.8932892 , 2. , 2.09455148])
  110. """
  111. res = _chandrupatla_iv(func, args, xatol, xrtol,
  112. fatol, frtol, maxiter, callback)
  113. func, args, xatol, xrtol, fatol, frtol, maxiter, callback = res
  114. # Initialization
  115. temp = eim._initialize(func, (a, b), args)
  116. func, xs, fs, args, shape, dtype, xp = temp
  117. x1, x2 = xs
  118. f1, f2 = fs
  119. status = xp.full_like(x1, eim._EINPROGRESS,
  120. dtype=xp.int32) # in progress
  121. nit, nfev = 0, 2 # two function evaluations performed above
  122. finfo = xp.finfo(dtype)
  123. xatol = 4*finfo.smallest_normal if xatol is None else xatol
  124. xrtol = 4*finfo.eps if xrtol is None else xrtol
  125. fatol = finfo.smallest_normal if fatol is None else fatol
  126. frtol = frtol * xp.minimum(xp.abs(f1), xp.abs(f2))
  127. maxiter = (math.log2(finfo.max) - math.log2(finfo.smallest_normal)
  128. if maxiter is None else maxiter)
  129. work = _RichResult(x1=x1, f1=f1, x2=x2, f2=f2, x3=None, f3=None, t=0.5,
  130. xatol=xatol, xrtol=xrtol, fatol=fatol, frtol=frtol,
  131. nit=nit, nfev=nfev, status=status)
  132. res_work_pairs = [('status', 'status'), ('x', 'xmin'), ('fun', 'fmin'),
  133. ('nit', 'nit'), ('nfev', 'nfev'), ('xl', 'x1'),
  134. ('fl', 'f1'), ('xr', 'x2'), ('fr', 'f2')]
  135. def pre_func_eval(work):
  136. # [1] Figure 1 (first box)
  137. x = work.x1 + work.t * (work.x2 - work.x1)
  138. return x
  139. def post_func_eval(x, f, work):
  140. # [1] Figure 1 (first diamond and boxes)
  141. # Note: y/n are reversed in figure; compare to BASIC in appendix
  142. work.x3, work.f3 = (xp.asarray(work.x2, copy=True),
  143. xp.asarray(work.f2, copy=True))
  144. j = xp.sign(f) == xp.sign(work.f1)
  145. nj = ~j
  146. work.x3[j], work.f3[j] = work.x1[j], work.f1[j]
  147. work.x2[nj], work.f2[nj] = work.x1[nj], work.f1[nj]
  148. work.x1, work.f1 = x, f
  149. def check_termination(work):
  150. # [1] Figure 1 (second diamond)
  151. # Check for all terminal conditions and record statuses.
  152. # See [1] Section 4 (first two sentences)
  153. i = xp.abs(work.f1) < xp.abs(work.f2)
  154. work.xmin = xp.where(i, work.x1, work.x2)
  155. work.fmin = xp.where(i, work.f1, work.f2)
  156. stop = xp.zeros_like(work.x1, dtype=xp.bool) # termination condition met
  157. # If function value tolerance is met, report successful convergence,
  158. # regardless of other conditions. Note that `frtol` has been redefined
  159. # as `frtol = frtol * minimum(f1, f2)`, where `f1` and `f2` are the
  160. # function evaluated at the original ends of the bracket.
  161. i = xp.abs(work.fmin) <= work.fatol + work.frtol
  162. work.status[i] = eim._ECONVERGED
  163. stop[i] = True
  164. # If the bracket is no longer valid, report failure (unless a function
  165. # tolerance is met, as detected above).
  166. i = (xp.sign(work.f1) == xp.sign(work.f2)) & ~stop
  167. NaN = xp.asarray(xp.nan, dtype=work.xmin.dtype)
  168. work.xmin[i], work.fmin[i], work.status[i] = NaN, NaN, eim._ESIGNERR
  169. stop[i] = True
  170. # If the abscissae are non-finite or either function value is NaN,
  171. # report failure.
  172. x_nonfinite = ~(xp.isfinite(work.x1) & xp.isfinite(work.x2))
  173. f_nan = xp.isnan(work.f1) & xp.isnan(work.f2)
  174. i = (x_nonfinite | f_nan) & ~stop
  175. work.xmin[i], work.fmin[i], work.status[i] = NaN, NaN, eim._EVALUEERR
  176. stop[i] = True
  177. # This is the convergence criterion used in bisect. Chandrupatla's
  178. # criterion is equivalent to this except with a factor of 4 on `xrtol`.
  179. work.dx = xp.abs(work.x2 - work.x1)
  180. work.tol = xp.abs(work.xmin) * work.xrtol + work.xatol
  181. i = work.dx < work.tol
  182. work.status[i] = eim._ECONVERGED
  183. stop[i] = True
  184. return stop
  185. def post_termination_check(work):
  186. # [1] Figure 1 (third diamond and boxes / Equation 1)
  187. xi1 = (work.x1 - work.x2) / (work.x3 - work.x2)
  188. with np.errstate(divide='ignore', invalid='ignore'):
  189. phi1 = (work.f1 - work.f2) / (work.f3 - work.f2)
  190. alpha = (work.x3 - work.x1) / (work.x2 - work.x1)
  191. j = ((1 - xp.sqrt(1 - xi1)) < phi1) & (phi1 < xp.sqrt(xi1))
  192. f1j, f2j, f3j, alphaj = work.f1[j], work.f2[j], work.f3[j], alpha[j]
  193. t = xp.full_like(alpha, 0.5)
  194. t[j] = (f1j / (f1j - f2j) * f3j / (f3j - f2j)
  195. - alphaj * f1j / (f3j - f1j) * f2j / (f2j - f3j))
  196. # [1] Figure 1 (last box; see also BASIC in appendix with comment
  197. # "Adjust T Away from the Interval Boundary")
  198. tl = 0.5 * work.tol / work.dx
  199. work.t = xp.clip(t, tl, 1 - tl)
  200. def customize_result(res, shape):
  201. xl, xr, fl, fr = res['xl'], res['xr'], res['fl'], res['fr']
  202. i = res['xl'] < res['xr']
  203. res['xl'] = xp.where(i, xl, xr)
  204. res['xr'] = xp.where(i, xr, xl)
  205. res['fl'] = xp.where(i, fl, fr)
  206. res['fr'] = xp.where(i, fr, fl)
  207. return shape
  208. return eim._loop(work, callback, shape, maxiter, func, args, dtype,
  209. pre_func_eval, post_func_eval, check_termination,
  210. post_termination_check, customize_result, res_work_pairs,
  211. xp=xp)
  212. def _chandrupatla_iv(func, args, xatol, xrtol,
  213. fatol, frtol, maxiter, callback):
  214. # Input validation for `_chandrupatla`
  215. if not callable(func):
  216. raise ValueError('`func` must be callable.')
  217. if not np.iterable(args):
  218. args = (args,)
  219. # tolerances are floats, not arrays; OK to use NumPy
  220. tols = np.asarray([xatol if xatol is not None else 1,
  221. xrtol if xrtol is not None else 1,
  222. fatol if fatol is not None else 1,
  223. frtol if frtol is not None else 1])
  224. if (not np.issubdtype(tols.dtype, np.number) or np.any(tols < 0)
  225. or np.any(np.isnan(tols)) or tols.shape != (4,)):
  226. raise ValueError('Tolerances must be non-negative scalars.')
  227. if maxiter is not None:
  228. maxiter_int = int(maxiter)
  229. if maxiter != maxiter_int or maxiter < 0:
  230. raise ValueError('`maxiter` must be a non-negative integer.')
  231. if callback is not None and not callable(callback):
  232. raise ValueError('`callback` must be callable.')
  233. return func, args, xatol, xrtol, fatol, frtol, maxiter, callback
  234. def _chandrupatla_minimize(func, x1, x2, x3, *, args=(), xatol=None,
  235. xrtol=None, fatol=None, frtol=None, maxiter=100,
  236. callback=None):
  237. """Find the minimizer of an elementwise function.
  238. For each element of the output of `func`, `_chandrupatla_minimize` seeks
  239. the scalar minimizer that minimizes the element. This function allows for
  240. `x1`, `x2`, `x3`, and the elements of `args` to be arrays of any
  241. broadcastable shapes.
  242. Parameters
  243. ----------
  244. func : callable
  245. The function whose minimizer is desired. The signature must be::
  246. func(x: ndarray, *args) -> ndarray
  247. where each element of ``x`` is a finite real and ``args`` is a tuple,
  248. which may contain an arbitrary number of arrays that are broadcastable
  249. with `x`. ``func`` must be an elementwise function: each element
  250. ``func(x)[i]`` must equal ``func(x[i])`` for all indices ``i``.
  251. `_chandrupatla` seeks an array ``x`` such that ``func(x)`` is an array
  252. of minima.
  253. x1, x2, x3 : array_like
  254. The abscissae of a standard scalar minimization bracket. A bracket is
  255. valid if ``x1 < x2 < x3`` and ``func(x1) > func(x2) <= func(x3)``.
  256. Must be broadcastable with one another and `args`.
  257. args : tuple, optional
  258. Additional positional arguments to be passed to `func`. Must be arrays
  259. broadcastable with `x1`, `x2`, and `x3`. If the callable to be
  260. differentiated requires arguments that are not broadcastable with `x`,
  261. wrap that callable with `func` such that `func` accepts only `x` and
  262. broadcastable arrays.
  263. xatol, xrtol, fatol, frtol : float, optional
  264. Absolute and relative tolerances on the minimizer and function value.
  265. See Notes for details.
  266. maxiter : int, optional
  267. The maximum number of iterations of the algorithm to perform.
  268. callback : callable, optional
  269. An optional user-supplied function to be called before the first
  270. iteration and after each iteration.
  271. Called as ``callback(res)``, where ``res`` is a ``_RichResult``
  272. similar to that returned by `_chandrupatla_minimize` (but containing
  273. the current iterate's values of all variables). If `callback` raises a
  274. ``StopIteration``, the algorithm will terminate immediately and
  275. `_chandrupatla_minimize` will return a result.
  276. Returns
  277. -------
  278. res : _RichResult
  279. An instance of `scipy._lib._util._RichResult` with the following
  280. attributes. (The descriptions are written as though the values will be
  281. scalars; however, if `func` returns an array, the outputs will be
  282. arrays of the same shape.)
  283. success : bool
  284. ``True`` when the algorithm terminated successfully (status ``0``).
  285. status : int
  286. An integer representing the exit status of the algorithm.
  287. ``0`` : The algorithm converged to the specified tolerances.
  288. ``-1`` : The algorithm encountered an invalid bracket.
  289. ``-2`` : The maximum number of iterations was reached.
  290. ``-3`` : A non-finite value was encountered.
  291. ``-4`` : Iteration was terminated by `callback`.
  292. ``1`` : The algorithm is proceeding normally (in `callback` only).
  293. x : float
  294. The minimizer of the function, if the algorithm terminated
  295. successfully.
  296. fun : float
  297. The value of `func` evaluated at `x`.
  298. nfev : int
  299. The number of points at which `func` was evaluated.
  300. nit : int
  301. The number of iterations of the algorithm that were performed.
  302. xl, xm, xr : float
  303. The final three-point bracket.
  304. fl, fm, fr : float
  305. The function value at the bracket points.
  306. Notes
  307. -----
  308. Implemented based on Chandrupatla's original paper [1]_.
  309. If ``x1 < x2 < x3`` are the points of the bracket and ``f1 > f2 <= f3``
  310. are the values of ``func`` at those points, then the algorithm is
  311. considered to have converged when ``x3 - x1 <= abs(x2)*xrtol + xatol``
  312. or ``(f1 - 2*f2 + f3)/2 <= abs(f2)*frtol + fatol``. Note that first of
  313. these differs from the termination conditions described in [1]_. The
  314. default values of `xrtol` is the square root of the precision of the
  315. appropriate dtype, and ``xatol = fatol = frtol`` is the smallest normal
  316. number of the appropriate dtype.
  317. References
  318. ----------
  319. .. [1] Chandrupatla, Tirupathi R. (1998).
  320. "An efficient quadratic fit-sectioning algorithm for minimization
  321. without derivatives".
  322. Computer Methods in Applied Mechanics and Engineering, 152 (1-2),
  323. 211-217. https://doi.org/10.1016/S0045-7825(97)00190-4
  324. See Also
  325. --------
  326. golden, brent, bounded
  327. Examples
  328. --------
  329. >>> from scipy.optimize._chandrupatla import _chandrupatla_minimize
  330. >>> def f(x, args=1):
  331. ... return (x - args)**2
  332. >>> res = _chandrupatla_minimize(f, -5, 0, 5)
  333. >>> res.x
  334. 1.0
  335. >>> c = [1, 1.5, 2]
  336. >>> res = _chandrupatla_minimize(f, -5, 0, 5, args=(c,))
  337. >>> res.x
  338. array([1. , 1.5, 2. ])
  339. """
  340. res = _chandrupatla_iv(func, args, xatol, xrtol,
  341. fatol, frtol, maxiter, callback)
  342. func, args, xatol, xrtol, fatol, frtol, maxiter, callback = res
  343. # Initialization
  344. xs = (x1, x2, x3)
  345. temp = eim._initialize(func, xs, args)
  346. func, xs, fs, args, shape, dtype, xp = temp # line split for PEP8
  347. x1, x2, x3 = xs
  348. f1, f2, f3 = fs
  349. phi = xp.asarray(0.5 + 0.5*5**0.5, dtype=dtype)[()] # golden ratio
  350. status = xp.full_like(x1, eim._EINPROGRESS, dtype=xp.int32) # in progress
  351. nit, nfev = 0, 3 # three function evaluations performed above
  352. fatol = xp.finfo(dtype).smallest_normal if fatol is None else fatol
  353. frtol = xp.finfo(dtype).smallest_normal if frtol is None else frtol
  354. xatol = xp.finfo(dtype).smallest_normal if xatol is None else xatol
  355. xrtol = math.sqrt(xp.finfo(dtype).eps) if xrtol is None else xrtol
  356. # Ensure that x1 < x2 < x3 initially.
  357. xs, fs = xp.stack((x1, x2, x3)), xp.stack((f1, f2, f3))
  358. i = xp.argsort(xs, axis=0)
  359. x1, x2, x3 = xp.take_along_axis(xs, i, axis=0) # data-apis/array-api#808
  360. f1, f2, f3 = xp.take_along_axis(fs, i, axis=0) # data-apis/array-api#808
  361. q0 = xp_copy(x3) # "At the start, q0 is set at x3..." ([1] after (7))
  362. work = _RichResult(x1=x1, f1=f1, x2=x2, f2=f2, x3=x3, f3=f3, phi=phi,
  363. xatol=xatol, xrtol=xrtol, fatol=fatol, frtol=frtol,
  364. nit=nit, nfev=nfev, status=status, q0=q0, args=args)
  365. res_work_pairs = [('status', 'status'),
  366. ('x', 'x2'), ('fun', 'f2'),
  367. ('nit', 'nit'), ('nfev', 'nfev'),
  368. ('xl', 'x1'), ('xm', 'x2'), ('xr', 'x3'),
  369. ('fl', 'f1'), ('fm', 'f2'), ('fr', 'f3')]
  370. def pre_func_eval(work):
  371. # `_check_termination` is called first -> `x3 - x2 > x2 - x1`
  372. # But let's calculate a few terms that we'll reuse
  373. x21 = work.x2 - work.x1
  374. x32 = work.x3 - work.x2
  375. # [1] Section 3. "The quadratic minimum point Q1 is calculated using
  376. # the relations developed in the previous section." [1] Section 2 (5/6)
  377. A = x21 * (work.f3 - work.f2)
  378. B = x32 * (work.f1 - work.f2)
  379. C = A / (A + B)
  380. # q1 = C * (work.x1 + work.x2) / 2 + (1 - C) * (work.x2 + work.x3) / 2
  381. q1 = 0.5 * (C*(work.x1 - work.x3) + work.x2 + work.x3) # much faster
  382. # this is an array, so multiplying by 0.5 does not change dtype
  383. # "If Q1 and Q0 are sufficiently close... Q1 is accepted if it is
  384. # sufficiently away from the inside point x2"
  385. i = xp.abs(q1 - work.q0) < 0.5 * xp.abs(x21) # [1] (7)
  386. xi = q1[i]
  387. # Later, after (9), "If the point Q1 is in a +/- xtol neighborhood of
  388. # x2, the new point is chosen in the larger interval at a distance
  389. # tol away from x2."
  390. # See also QBASIC code after "Accept Ql adjust if close to X2".
  391. j = xp.abs(q1[i] - work.x2[i]) <= work.xtol[i]
  392. xi[j] = work.x2[i][j] + xp.sign(x32[i][j]) * work.xtol[i][j]
  393. # "If condition (7) is not satisfied, golden sectioning of the larger
  394. # interval is carried out to introduce the new point."
  395. # (For simplicity, we go ahead and calculate it for all points, but we
  396. # change the elements for which the condition was satisfied.)
  397. x = work.x2 + (2 - work.phi) * x32
  398. x[i] = xi
  399. # "We define Q0 as the value of Q1 at the previous iteration."
  400. work.q0 = q1
  401. return x
  402. def post_func_eval(x, f, work):
  403. # Standard logic for updating a three-point bracket based on a new
  404. # point. In QBASIC code, see "IF SGN(X-X2) = SGN(X3-X2) THEN...".
  405. # There is an awful lot of data copying going on here; this would
  406. # probably benefit from code optimization or implementation in Pythran.
  407. i = xp.sign(x - work.x2) == xp.sign(work.x3 - work.x2)
  408. xi, x1i, x2i, x3i = x[i], work.x1[i], work.x2[i], work.x3[i],
  409. fi, f1i, f2i, f3i = f[i], work.f1[i], work.f2[i], work.f3[i]
  410. j = fi > f2i
  411. x3i[j], f3i[j] = xi[j], fi[j]
  412. j = ~j
  413. x1i[j], f1i[j], x2i[j], f2i[j] = x2i[j], f2i[j], xi[j], fi[j]
  414. ni = ~i
  415. xni, x1ni, x2ni, x3ni = x[ni], work.x1[ni], work.x2[ni], work.x3[ni],
  416. fni, f1ni, f2ni, f3ni = f[ni], work.f1[ni], work.f2[ni], work.f3[ni]
  417. j = fni > f2ni
  418. x1ni[j], f1ni[j] = xni[j], fni[j]
  419. j = ~j
  420. x3ni[j], f3ni[j], x2ni[j], f2ni[j] = x2ni[j], f2ni[j], xni[j], fni[j]
  421. work.x1[i], work.x2[i], work.x3[i] = x1i, x2i, x3i
  422. work.f1[i], work.f2[i], work.f3[i] = f1i, f2i, f3i
  423. work.x1[ni], work.x2[ni], work.x3[ni] = x1ni, x2ni, x3ni,
  424. work.f1[ni], work.f2[ni], work.f3[ni] = f1ni, f2ni, f3ni
  425. def check_termination(work):
  426. # Check for all terminal conditions and record statuses.
  427. stop = xp.zeros_like(work.x1, dtype=bool) # termination condition met
  428. # Bracket is invalid; stop and don't return minimizer/minimum
  429. i = ((work.f2 > work.f1) | (work.f2 > work.f3))
  430. work.x2[i], work.f2[i] = xp.nan, xp.nan
  431. stop[i], work.status[i] = True, eim._ESIGNERR
  432. # Non-finite values; stop and don't return minimizer/minimum
  433. finite = xp.isfinite(work.x1+work.x2+work.x3+work.f1+work.f2+work.f3)
  434. i = ~(finite | stop)
  435. work.x2[i], work.f2[i] = xp.nan, xp.nan
  436. stop[i], work.status[i] = True, eim._EVALUEERR
  437. # [1] Section 3 "Points 1 and 3 are interchanged if necessary to make
  438. # the (x2, x3) the larger interval."
  439. # Note: I had used np.choose; this is much faster. This would be a good
  440. # place to save e.g. `work.x3 - work.x2` for reuse, but I tried and
  441. # didn't notice a speed boost, so let's keep it simple.
  442. i = xp.abs(work.x3 - work.x2) < xp.abs(work.x2 - work.x1)
  443. temp = work.x1[i]
  444. work.x1[i] = work.x3[i]
  445. work.x3[i] = temp
  446. temp = work.f1[i]
  447. work.f1[i] = work.f3[i]
  448. work.f3[i] = temp
  449. # [1] Section 3 (bottom of page 212)
  450. # "We set a tolerance value xtol..."
  451. work.xtol = xp.abs(work.x2) * work.xrtol + work.xatol # [1] (8)
  452. # "The convergence based on interval is achieved when..."
  453. # Note: Equality allowed in case of `xtol=0`
  454. i = xp.abs(work.x3 - work.x2) <= 2 * work.xtol # [1] (9)
  455. # "We define ftol using..."
  456. ftol = xp.abs(work.f2) * work.frtol + work.fatol # [1] (10)
  457. # "The convergence based on function values is achieved when..."
  458. # Note 1: modify in place to incorporate tolerance on function value.
  459. # Note 2: factor of 2 is not in the text; see QBASIC start of DO loop
  460. i |= (work.f1 - 2 * work.f2 + work.f3) <= 2*ftol # [1] (11)
  461. i &= ~stop
  462. stop[i], work.status[i] = True, eim._ECONVERGED
  463. return stop
  464. def post_termination_check(work):
  465. pass
  466. def customize_result(res, shape):
  467. xl, xr, fl, fr = res['xl'], res['xr'], res['fl'], res['fr']
  468. i = res['xl'] >= res['xr']
  469. res['xl'] = xp.where(i, xr, xl)
  470. res['xr'] = xp.where(i, xl, xr)
  471. res['fl'] = xp.where(i, fr, fl)
  472. res['fr'] = xp.where(i, fl, fr)
  473. return shape
  474. return eim._loop(work, callback, shape, maxiter, func, args, dtype,
  475. pre_func_eval, post_func_eval, check_termination,
  476. post_termination_check, customize_result, res_work_pairs,
  477. xp=xp)