_bracket.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706
  1. import numpy as np
  2. import scipy._lib._elementwise_iterative_method as eim
  3. from scipy._lib._util import _RichResult
  4. from scipy._lib._array_api import array_namespace, xp_ravel, xp_promote
  5. _ELIMITS = -1 # used in _bracket_root
  6. _ESTOPONESIDE = 2 # used in _bracket_root
  7. def _bracket_root_iv(func, xl0, xr0, xmin, xmax, factor, args, maxiter):
  8. if not callable(func):
  9. raise ValueError('`func` must be callable.')
  10. if not np.iterable(args):
  11. args = (args,)
  12. xp = array_namespace(xl0, xr0, xmin, xmax, factor, *args)
  13. # If xr0 is not supplied, fill with a dummy value for the sake of
  14. # broadcasting. We need to wait until xmax has been validated to
  15. # compute the default value.
  16. xr0_not_supplied = False
  17. if xr0 is None:
  18. xr0 = xp.nan
  19. xr0_not_supplied = True
  20. xmin = -xp.inf if xmin is None else xmin
  21. xmax = xp.inf if xmax is None else xmax
  22. factor = 2. if factor is None else factor
  23. xl0, xr0, xmin, xmax, factor = xp_promote(
  24. xl0, xr0, xmin, xmax, factor, broadcast=True, force_floating=True, xp=xp)
  25. if not xp.isdtype(xl0.dtype, ('integral', 'real floating')):
  26. raise ValueError('`xl0` must be numeric and real.')
  27. if (not xp.isdtype(xr0.dtype, "numeric")
  28. or xp.isdtype(xr0.dtype, "complex floating")):
  29. raise ValueError('`xr0` must be numeric and real.')
  30. if (not xp.isdtype(xmin.dtype, "numeric")
  31. or xp.isdtype(xmin.dtype, "complex floating")):
  32. raise ValueError('`xmin` must be numeric and real.')
  33. if (not xp.isdtype(xmax.dtype, "numeric")
  34. or xp.isdtype(xmax.dtype, "complex floating")):
  35. raise ValueError('`xmax` must be numeric and real.')
  36. if (not xp.isdtype(factor.dtype, "numeric")
  37. or xp.isdtype(factor.dtype, "complex floating")):
  38. raise ValueError('`factor` must be numeric and real.')
  39. if not xp.all(factor > 1):
  40. raise ValueError('All elements of `factor` must be greater than 1.')
  41. # Calculate the default value of xr0 if a value has not been supplied.
  42. # Be careful to ensure xr0 is not larger than xmax.
  43. if xr0_not_supplied:
  44. xr0 = xl0 + xp.minimum((xmax - xl0)/ 8, 1.0)
  45. xr0 = xp.astype(xr0, xl0.dtype, copy=False)
  46. maxiter = xp.asarray(maxiter)
  47. message = '`maxiter` must be a non-negative integer.'
  48. if (not xp.isdtype(maxiter.dtype, "numeric") or maxiter.shape != tuple()
  49. or xp.isdtype(maxiter.dtype, "complex floating")):
  50. raise ValueError(message)
  51. maxiter_int = int(maxiter[()])
  52. if not maxiter == maxiter_int or maxiter < 0:
  53. raise ValueError(message)
  54. return func, xl0, xr0, xmin, xmax, factor, args, maxiter, xp
def _bracket_root(func, xl0, xr0=None, *, xmin=None, xmax=None, factor=None,
                  args=(), maxiter=1000):
    """Bracket the root of a monotonic scalar function of one variable

    This function works elementwise when `xl0`, `xr0`, `xmin`, `xmax`, `factor`, and
    the elements of `args` are broadcastable arrays.

    Parameters
    ----------
    func : callable
        The function for which the root is to be bracketed.
        The signature must be::

            func(x: ndarray, *args) -> ndarray

        where each element of ``x`` is a finite real and ``args`` is a tuple,
        which may contain an arbitrary number of arrays that are broadcastable
        with `x`. ``func`` must be an elementwise function: each element
        ``func(x)[i]`` must equal ``func(x[i])`` for all indices ``i``.
    xl0, xr0: float array_like
        Starting guess of bracket, which need not contain a root. If `xr0` is
        not provided, ``xr0 = xl0 + min((xmax - xl0)/8, 1)``; this equals
        ``xl0 + 1`` when `xmax` is not provided and keeps the default `xr0`
        from exceeding `xmax`. Must be broadcastable with one another.
    xmin, xmax : float array_like, optional
        Minimum and maximum allowable endpoints of the bracket, inclusive. Must
        be broadcastable with `xl0` and `xr0`. If not provided, they default
        to ``-inf`` and ``inf`` (no limit), respectively.
    factor : float array_like, default: 2
        The factor used to grow the bracket. All elements must be greater
        than 1. See notes for details.
    args : tuple, optional
        Additional positional arguments to be passed to `func`. Must be arrays
        broadcastable with `xl0`, `xr0`, `xmin`, and `xmax`. If the callable to be
        bracketed requires arguments that are not broadcastable with these
        arrays, wrap that callable with `func` such that `func` accepts
        only `x` and broadcastable arrays.
    maxiter : int, optional
        The maximum number of iterations of the algorithm to perform. Must be
        a non-negative integer.

    Returns
    -------
    res : _RichResult
        An instance of `scipy._lib._util._RichResult` with the following
        attributes. The descriptions are written as though the values will be
        scalars; however, if `func` returns an array, the outputs will be
        arrays of the same shape.

        xl, xr : float
            The lower and upper ends of the bracket, if the algorithm
            terminated successfully.
        fl, fr : float
            The function value at the lower and upper ends of the bracket.
        nfev : int
            The number of function evaluations required to find the bracket.
            This is distinct from the number of times `func` is *called*
            because the function may be evaluated at multiple points in a single
            call.
        nit : int
            The number of iterations of the algorithm that were performed.
        status : int
            An integer representing the exit status of the algorithm.

            - ``0`` : The algorithm produced a valid bracket.
            - ``-1`` : The bracket expanded to the allowable limits without finding a bracket.
            - ``-2`` : The maximum number of iterations was reached.
            - ``-3`` : A non-finite value was encountered.
            - ``-4`` : Iteration was terminated by `callback` (not currently
              reachable: no callback is passed to the solver loop).
            - ``-5``: The initial bracket does not satisfy `xmin <= xl0 < xr0 <= xmax`.
            - ``1`` : The algorithm is proceeding normally (in `callback` only).
            - ``2`` : A bracket was found in the opposite search direction (in `callback` only).

        success : bool
            ``True`` when the algorithm terminated successfully (status ``0``).

    Notes
    -----
    This function generalizes an algorithm found in pieces throughout
    `scipy.stats`. The strategy is to iteratively grow the bracket ``(l, r)``
    until ``func(l) < 0 < func(r)``. The bracket grows to the left as follows.

    - If `xmin` is not provided, the distance between `xl0` and `l` is iteratively
      increased by `factor`.
    - If `xmin` is provided, the distance between `xmin` and `l` is iteratively
      decreased by `factor`. Note that this also *increases* the bracket size.

    Growth of the bracket to the right is analogous.

    Growth of the bracket in one direction stops when the endpoint is no longer
    finite, the function value at the endpoint is no longer finite, or the
    endpoint reaches its limiting value (`xmin` or `xmax`). Iteration terminates
    when the bracket stops growing in both directions, the bracket surrounds
    the root, or a root is found (accidentally).

    If two brackets are found - that is, a bracket is found on both sides in
    the same iteration, the smaller of the two is returned.

    If the function value at the moving end of a one-sided search, or at its
    previous location, is exactly zero, that search is treated as having
    found a bracket (status ``0``).
    """  # noqa: E501
    # Todo:
    # - find bracket with sign change in specified direction
    # - Add tolerance
    # - allow factor < 1?

    callback = None  # works; I just don't want to test it
    temp = _bracket_root_iv(func, xl0, xr0, xmin, xmax, factor, args, maxiter)
    func, xl0, xr0, xmin, xmax, factor, args, maxiter, xp = temp

    xs = (xl0, xr0)
    temp = eim._initialize(func, xs, args)
    func, xs, fs, args, shape, dtype, xp = temp  # line split for PEP8
    xl0, xr0 = xs
    # Broadcast the limits to the common shape, cast to the working dtype and
    # flatten, matching the flattened `xl0`/`xr0` returned by `_initialize`.
    xmin = xp_ravel(xp.astype(xp.broadcast_to(xmin, shape), dtype, copy=False), xp=xp)
    xmax = xp_ravel(xp.astype(xp.broadcast_to(xmax, shape), dtype, copy=False), xp=xp)
    invalid_bracket = ~((xmin <= xl0) & (xl0 < xr0) & (xr0 <= xmax))

    # The approach is to treat the left and right searches as though they were
    # (almost) totally independent one-sided bracket searches. (The interaction
    # is considered when checking for termination and preparing the result
    # object.)
    # `x` is the "moving" end of the bracket. The first `n` elements belong to
    # the leftward searches (starting at `xl0`), the last `n` to the rightward
    # searches (starting at `xr0`).
    x = xp.concat(xs)
    f = xp.concat(fs)
    invalid_bracket = xp.concat((invalid_bracket, invalid_bracket))
    n = x.shape[0] // 2

    # `x_last` is the previous location of the moving end of the bracket. If
    # the signs of `f` and `f_last` are different, `x` and `x_last` form a
    # bracket.
    x_last = xp.concat((x[n:], x[:n]))
    f_last = xp.concat((f[n:], f[:n]))
    # `x0` is the "fixed" end of the bracket.
    x0 = x_last
    # We don't need to retain the corresponding function value, since the
    # fixed end of the bracket is only needed to compute the new value of the
    # moving end; it is never returned.
    # Leftward searches are limited by `xmin`, rightward searches by `xmax`.
    limit = xp.concat((xmin, xmax))

    factor = xp_ravel(xp.broadcast_to(factor, shape), xp=xp)
    factor = xp.astype(factor, dtype, copy=False)
    factor = xp.concat((factor, factor))

    # Unique integer id of each one-sided search; see Condition 2 of
    # `check_termination`.
    active = xp.arange(2*n)
    args = [xp.concat((arg, arg)) for arg in args]

    # This is needed due to inner workings of `eim._loop`.
    # We're abusing it a tiny bit. (`customize_result` returns `shape[:-1]`
    # to drop the extra dimension again.)
    shape = shape + (2,)

    # `d` is for "distance".
    # For searches without a limit, the distance between the fixed end of the
    # bracket `x0` and the moving end `x` will grow by `factor` each iteration.
    # For searches with a limit, the distance between the `limit` and moving
    # end of the bracket `x` will shrink by `factor` each iteration.
    i = xp.isinf(limit)
    ni = ~i
    d = xp.zeros_like(x)
    d[i] = x[i] - x0[i]
    d[ni] = limit[ni] - x[ni]

    status = xp.full_like(x, eim._EINPROGRESS, dtype=xp.int32)  # in progress
    status[invalid_bracket] = eim._EINPUTERR
    nit, nfev = 0, 1  # one function evaluation per side performed above

    work = _RichResult(x=x, x0=x0, f=f, limit=limit, factor=factor,
                       active=active, d=d, x_last=x_last, f_last=f_last,
                       nit=nit, nfev=nfev, status=status, args=args,
                       xl=xp.nan, xr=xp.nan, fl=xp.nan, fr=xp.nan, n=n)
    res_work_pairs = [('status', 'status'), ('xl', 'xl'), ('xr', 'xr'),
                      ('nit', 'nit'), ('nfev', 'nfev'), ('fl', 'fl'),
                      ('fr', 'fr'), ('x', 'x'), ('f', 'f'),
                      ('x_last', 'x_last'), ('f_last', 'f_last')]

    def pre_func_eval(work):
        # Compute the next location of the moving end of each active search.
        # Initialize moving end of bracket
        x = xp.zeros_like(work.x)

        # Unlimited brackets grow by `factor` by increasing distance from fixed
        # end to moving end.
        i = xp.isinf(work.limit)  # indices of unlimited brackets
        work.d[i] *= work.factor[i]
        x[i] = work.x0[i] + work.d[i]

        # Limited brackets grow by decreasing the distance from the limit to
        # the moving end.
        ni = ~i  # indices of limited brackets
        work.d[ni] /= work.factor[ni]
        x[ni] = work.limit[ni] - work.d[ni]

        return x

    def post_func_eval(x, f, work):
        # Keep track of the previous location of the moving end so that we can
        # return a narrower bracket. (The alternative is to remember the
        # original fixed end, but then the bracket would be wider than needed.)
        work.x_last = work.x
        work.f_last = work.f
        work.x = x
        work.f = f

    def check_termination(work):
        # Return a boolean mask of active searches that should stop, setting
        # `work.status` for each of them.
        # Condition 0: initial bracket is invalid
        stop = (work.status == eim._EINPUTERR)

        # Condition 1: a valid bracket (or the root itself) has been found
        sf = xp.sign(work.f)
        sf_last = xp.sign(work.f_last)
        i = ((sf_last == -sf) | (sf_last == 0) | (sf == 0)) & ~stop
        work.status[i] = eim._ECONVERGED
        stop[i] = True

        # Condition 2: the other side's search found a valid bracket.
        # (If we just found a bracket with the rightward search, we can stop
        #  the leftward search, and vice-versa.)
        # To do this, we need to set the status of the other side's search;
        # this is tricky because `work.status` contains only the *active*
        # elements, so we don't immediately know the index of the element we
        # need to set - or even if it's still there. (That search may have
        # terminated already, e.g. by reaching its `limit`.)
        # To facilitate this, `work.active` contains a unique integer index of
        # each search. Index `k` (`k < n)` and `k + n` correspond with a
        # leftward and rightward search, respectively. Elements are removed
        # from `work.active` just as they are removed from `work.status`, so
        # we use `work.active` to help find the right location in
        # `work.status`.
        # Get the integer indices of the elements that can also stop
        also_stop = (work.active[i] + work.n) % (2*work.n)
        # Check whether they are still active. We want to find the indices
        # in work.active where the associated values in work.active are
        # contained in also_stop. xp.searchsorted lets us take advantage
        # of work.active being sorted, but requires some hackery because
        # searchsorted solves the separate but related problem of finding
        # the indices where the values in also_stop should be added to
        # maintain sorted order.
        j = xp.searchsorted(work.active, also_stop)
        # If the location exceeds the length of the `work.active`, they are
        # not there. This happens when a value in also_stop is larger than
        # the greatest value in work.active. This case needs special handling
        # because we cannot simply check that also_stop == work.active[j].
        mask = j < work.active.shape[0]
        # Note that we also have to use the mask to filter also_stop to ensure
        # that also_stop and j will still have the same shape.
        j, also_stop = j[mask], also_stop[mask]
        j = j[also_stop == work.active[j]]
        # Now convert these to boolean indices to use with `work.status`.
        i = xp.zeros_like(stop)
        i[j] = True  # boolean indices of elements that can also stop
        i = i & ~stop
        work.status[i] = _ESTOPONESIDE
        stop[i] = True

        # Condition 3: moving end of bracket reaches limit
        i = (work.x == work.limit) & ~stop
        work.status[i] = _ELIMITS
        stop[i] = True

        # Condition 4: non-finite value encountered
        i = ~(xp.isfinite(work.x) & xp.isfinite(work.f)) & ~stop
        work.status[i] = eim._EVALUEERR
        stop[i] = True

        return stop

    def post_termination_check(work):
        pass

    def customize_result(res, shape):
        # Merge the two one-sided searches of each element into a single
        # bracket; returns the user-facing shape (without the trailing 2).
        n = res['x'].shape[0] // 2

        # To avoid ambiguity, below we refer to `xl0`, the initial left endpoint
        # as `a` and `xr0`, the initial right endpoint, as `b`.
        # Because we treat the two one-sided searches as though they were
        # independent, what we keep track of in `work` and what we want to
        # return in `res` look quite different. Combine the results from the
        # two one-sided searches before reporting the results to the user.
        # - "a" refers to the leftward search (the moving end started at `a`)
        # - "b" refers to the rightward search (the moving end started at `b`)
        # - "l" refers to the left end of the bracket (closer to -oo)
        # - "r" refers to the right end of the bracket (closer to +oo)
        xal = res['x'][:n]
        xar = res['x_last'][:n]
        xbl = res['x_last'][n:]
        xbr = res['x'][n:]

        fal = res['f'][:n]
        far = res['f_last'][:n]
        fbl = res['f_last'][n:]
        fbr = res['f'][n:]

        # Initialize the brackets and corresponding function values to return
        # to the user. Brackets may not be valid (e.g. there is no root,
        # there weren't enough iterations, NaN encountered), but we still need
        # to return something. One option would be all NaNs, but what I've
        # chosen here is the left- and right-most points at which the function
        # has been evaluated. This gives the user some information about what
        # interval of the real line has been searched and shows that there is
        # no sign change between the two ends.
        xl = xp.asarray(xal, copy=True)
        fl = xp.asarray(fal, copy=True)
        xr = xp.asarray(xbr, copy=True)
        fr = xp.asarray(fbr, copy=True)

        # `status` indicates whether the bracket is valid or not. If so,
        # we want to adjust the bracket we return to be the narrowest possible
        # given the points at which we evaluated the function.
        # For example if bracket "a" is valid and smaller than bracket "b" OR
        # if bracket "a" is valid and bracket "b" is not valid, we want to
        # return bracket "a" (and vice versa).
        sa = res['status'][:n]
        sb = res['status'][n:]

        da = xar - xal
        db = xbr - xbl

        i1 = ((da <= db) & (sa == 0)) | ((sa == 0) & (sb != 0))
        i2 = ((db <= da) & (sb == 0)) | ((sb == 0) & (sa != 0))

        xr[i1] = xar[i1]
        fr[i1] = far[i1]
        xl[i2] = xbl[i2]
        fl[i2] = fbl[i2]

        # Finish assembling the result object
        res['xl'] = xl
        res['xr'] = xr
        res['fl'] = fl
        res['fr'] = fr

        res['nit'] = xp.maximum(res['nit'][:n], res['nit'][n:])
        res['nfev'] = res['nfev'][:n] + res['nfev'][n:]
        # If the status on one side is zero, the status is zero. In any case,
        # report the status from one side only.
        res['status'] = xp.where(sa == 0, sa, sb)
        res['success'] = (res['status'] == 0)

        del res['x']
        del res['f']
        del res['x_last']
        del res['f_last']

        return shape[:-1]

    return eim._loop(work, callback, shape, maxiter, func, args, dtype,
                     pre_func_eval, post_func_eval, check_termination,
                     post_termination_check, customize_result, res_work_pairs,
                     xp)
  349. def _bracket_minimum_iv(func, xm0, xl0, xr0, xmin, xmax, factor, args, maxiter):
  350. if not callable(func):
  351. raise ValueError('`func` must be callable.')
  352. if not np.iterable(args):
  353. args = (args,)
  354. xp = array_namespace(xm0, xl0, xr0, xmin, xmax, factor, *args)
  355. xmin = -xp.inf if xmin is None else xmin
  356. xmax = xp.inf if xmax is None else xmax
  357. # If xl0 (xr0) is not supplied, fill with a dummy value for the sake
  358. # of broadcasting. We need to wait until xmin (xmax) has been validated
  359. # to compute the default values.
  360. xl0_not_supplied = False
  361. if xl0 is None:
  362. xl0 = xp.nan
  363. xl0_not_supplied = True
  364. xr0_not_supplied = False
  365. if xr0 is None:
  366. xr0 = xp.nan
  367. xr0_not_supplied = True
  368. factor = 2.0 if factor is None else factor
  369. xm0, xl0, xr0, xmin, xmax, factor = xp_promote(
  370. xm0, xl0, xr0, xmin, xmax, factor, broadcast=True, force_floating=True, xp=xp)
  371. if not xp.isdtype(xm0.dtype, ('integral', 'real floating')):
  372. raise ValueError('`xm0` must be numeric and real.')
  373. if (not xp.isdtype(xl0.dtype, "numeric")
  374. or xp.isdtype(xl0.dtype, "complex floating")):
  375. raise ValueError('`xl0` must be numeric and real.')
  376. if (not xp.isdtype(xr0.dtype, "numeric")
  377. or xp.isdtype(xr0.dtype, "complex floating")):
  378. raise ValueError('`xr0` must be numeric and real.')
  379. if (not xp.isdtype(xmin.dtype, "numeric")
  380. or xp.isdtype(xmin.dtype, "complex floating")):
  381. raise ValueError('`xmin` must be numeric and real.')
  382. if (not xp.isdtype(xmax.dtype, "numeric")
  383. or xp.isdtype(xmax.dtype, "complex floating")):
  384. raise ValueError('`xmax` must be numeric and real.')
  385. if (not xp.isdtype(factor.dtype, "numeric")
  386. or xp.isdtype(factor.dtype, "complex floating")):
  387. raise ValueError('`factor` must be numeric and real.')
  388. if not xp.all(factor > 1):
  389. raise ValueError('All elements of `factor` must be greater than 1.')
  390. # Calculate default values of xl0 and/or xr0 if they have not been supplied
  391. # by the user. We need to be careful to ensure xl0 and xr0 are not outside
  392. # of (xmin, xmax).
  393. if xl0_not_supplied:
  394. xl0 = xm0 - xp.minimum((xm0 - xmin)/16, 0.5)
  395. xl0 = xp.astype(xl0, xm0.dtype, copy=False)
  396. if xr0_not_supplied:
  397. xr0 = xm0 + xp.minimum((xmax - xm0)/16, 0.5)
  398. xr0 = xp.astype(xr0, xm0.dtype, copy=False)
  399. maxiter = xp.asarray(maxiter)
  400. message = '`maxiter` must be a non-negative integer.'
  401. if (not xp.isdtype(maxiter.dtype, "numeric") or maxiter.shape != tuple()
  402. or xp.isdtype(maxiter.dtype, "complex floating")):
  403. raise ValueError(message)
  404. maxiter_int = int(maxiter[()])
  405. if not maxiter == maxiter_int or maxiter < 0:
  406. raise ValueError(message)
  407. return func, xm0, xl0, xr0, xmin, xmax, factor, args, maxiter, xp
def _bracket_minimum(func, xm0, *, xl0=None, xr0=None, xmin=None, xmax=None,
                     factor=None, args=(), maxiter=1000):
    """Bracket the minimum of a unimodal scalar function of one variable

    This function works elementwise when `xm0`, `xl0`, `xr0`, `xmin`, `xmax`,
    and the elements of `args` are broadcastable arrays.

    Parameters
    ----------
    func : callable
        The function for which the minimum is to be bracketed.
        The signature must be::

            func(x: ndarray, *args) -> ndarray

        where each element of ``x`` is a finite real and ``args`` is a tuple,
        which may contain an arbitrary number of arrays that are broadcastable
        with ``x``. `func` must be an elementwise function: each element
        ``func(x)[i]`` must equal ``func(x[i])`` for all indices `i`.
    xm0: float array_like
        Starting guess for middle point of bracket.
    xl0, xr0: float array_like, optional
        Starting guesses for left and right endpoints of the bracket. Must be
        broadcastable with one another and with `xm0`. If not provided,
        ``xl0 = xm0 - min((xm0 - xmin)/16, 0.5)`` and
        ``xr0 = xm0 + min((xmax - xm0)/16, 0.5)``.
    xmin, xmax : float array_like, optional
        Minimum and maximum allowable endpoints of the bracket, inclusive. Must
        be broadcastable with `xl0`, `xm0`, and `xr0`. If not provided, they
        default to ``-inf`` and ``inf`` (no limit), respectively.
    factor : float array_like, optional
        Controls expansion of bracket endpoint in downhill direction. Works
        differently in the cases where a limit is set in the downhill direction
        with `xmax` or `xmin`. See Notes. All elements must be greater than 1;
        defaults to 2.
    args : tuple, optional
        Additional positional arguments to be passed to `func`. Must be arrays
        broadcastable with `xl0`, `xm0`, `xr0`, `xmin`, and `xmax`. If the
        callable to be bracketed requires arguments that are not broadcastable
        with these arrays, wrap that callable with `func` such that `func`
        accepts only ``x`` and broadcastable arrays.
    maxiter : int, optional
        The maximum number of iterations of the algorithm to perform. The number
        of function evaluations is three greater than the number of iterations.
        Must be a non-negative integer.

    Returns
    -------
    res : _RichResult
        An instance of `scipy._lib._util._RichResult` with the following
        attributes. The descriptions are written as though the values will be
        scalars; however, if `func` returns an array, the outputs will be
        arrays of the same shape.

        xl, xm, xr : float
            The left, middle, and right points of the bracket, if the algorithm
            terminated successfully.
        fl, fm, fr : float
            The function value at the left, middle, and right points of the bracket.
        nfev : int
            The number of function evaluations required to find the bracket.
        nit : int
            The number of iterations of the algorithm that were performed.
        status : int
            An integer representing the exit status of the algorithm.

            - ``0`` : The algorithm produced a valid bracket.
            - ``-1`` : The bracket expanded to the allowable limits. Assuming
              unimodality, this implies the endpoint at the limit is a
              minimizer.
            - ``-2`` : The maximum number of iterations was reached.
            - ``-3`` : A non-finite value was encountered.
            - ``-4`` : Iteration was terminated by `callback` (not currently
              reachable: no callback is passed to the solver loop).
            - ``-5`` : The initial bracket does not satisfy
              `xmin <= xl0 < xm0 < xr0 <= xmax`.

        success : bool
            ``True`` when the algorithm terminated successfully (status ``0``).

    Notes
    -----
    Similar to `scipy.optimize.bracket`, this function seeks to find real
    points ``xl < xm < xr`` such that ``f(xl) >= f(xm)`` and ``f(xr) >= f(xm)``,
    where at least one of the inequalities is strict. Unlike `scipy.optimize.bracket`,
    this function can operate in a vectorized manner on array input, so long as
    the input arrays are broadcastable with each other. Also unlike
    `scipy.optimize.bracket`, users may specify minimum and maximum endpoints
    for the desired bracket.

    Given an initial trio of points ``xl = xl0``, ``xm = xm0``, ``xr = xr0``,
    the algorithm checks if these points already give a valid bracket. If not,
    a new endpoint, ``w`` is chosen in the "downhill" direction, ``xm`` becomes the new
    opposite endpoint, and either `xl` or `xr` becomes the new middle point,
    depending on which direction is downhill. The algorithm repeats from here.

    The new endpoint `w` is chosen differently depending on whether or not a
    boundary `xmin` or `xmax` has been set in the downhill direction. Without
    loss of generality, suppose the downhill direction is to the right, so that
    ``f(xl) > f(xm) > f(xr)``. If there is no boundary to the right, then on
    the ``k``-th iteration `w` is chosen to be
    ``xr0 + factor**k * (xr0 - xm0)``, where ``xr0`` and ``xm0`` are the
    initial right and middle points and `factor` is controlled by the user
    (defaults to 2.0), so that step sizes increase in geometric proportion.
    If there is a boundary, `xmax` in this case, then `w` is chosen to be
    ``xmax - (xmax - xr)/factor``, with steps slowing to a stop at
    `xmax`; if rounding makes `w` equal to the current `xr`, `xmax` itself is
    used instead. This cautious approach ensures that a minimum near but distinct from
    the boundary isn't missed while also detecting whether or not the `xmax` is
    a minimizer when `xmax` is reached after a finite number of steps.
    """  # noqa: E501
    callback = None  # works; I just don't want to test it

    temp = _bracket_minimum_iv(func, xm0, xl0, xr0, xmin, xmax, factor, args, maxiter)
    func, xm0, xl0, xr0, xmin, xmax, factor, args, maxiter, xp = temp

    xs = (xl0, xm0, xr0)
    temp = eim._initialize(func, xs, args)
    func, xs, fs, args, shape, dtype, xp = temp

    xl0, xm0, xr0 = xs
    fl0, fm0, fr0 = fs
    # Broadcast the limits to the common shape, cast to the working dtype and
    # flatten, matching the flattened points returned by `_initialize`.
    xmin = xp.astype(xp.broadcast_to(xmin, shape), dtype, copy=False)
    xmin = xp_ravel(xmin, xp=xp)
    xmax = xp.astype(xp.broadcast_to(xmax, shape), dtype, copy=False)
    xmax = xp_ravel(xmax, xp=xp)
    invalid_bracket = ~((xmin <= xl0) & (xl0 < xm0) & (xm0 < xr0) & (xr0 <= xmax))
    # We will modify factor later on so make a copy. np.broadcast_to returns
    # a read-only view.
    factor = xp.astype(xp.broadcast_to(factor, shape), dtype, copy=True)
    factor = xp_ravel(factor)

    # To simplify the logic, swap xl and xr if f(xl) < f(xr). We should always be
    # marching downhill in the direction from xl to xr.
    comp = fl0 < fr0
    xl0[comp], xr0[comp] = xr0[comp], xl0[comp]
    fl0[comp], fr0[comp] = fr0[comp], fl0[comp]
    # We only need the boundary in the direction we're traveling.
    limit = xp.where(comp, xmin, xmax)

    unlimited = xp.isinf(limit)
    limited = ~unlimited
    step = xp.empty_like(xl0)

    # Without a limit, `step` is the offset of the next point from the initial
    # (post-swap) `xr0`; with a limit, it is the offset of the next point from
    # the limit itself.
    step[unlimited] = (xr0[unlimited] - xm0[unlimited])
    step[limited] = (limit[limited] - xr0[limited])

    # Step size is divided by factor for case where there is a limit.
    factor[limited] = 1 / factor[limited]

    status = xp.full_like(xl0, eim._EINPROGRESS, dtype=xp.int32)
    status[invalid_bracket] = eim._EINPUTERR
    nit, nfev = 0, 3

    work = _RichResult(xl=xl0, xm=xm0, xr=xr0, xr0=xr0, fl=fl0, fm=fm0, fr=fr0,
                       step=step, limit=limit, limited=limited, factor=factor, nit=nit,
                       nfev=nfev, status=status, args=args)

    res_work_pairs = [('status', 'status'), ('xl', 'xl'), ('xm', 'xm'), ('xr', 'xr'),
                      ('nit', 'nit'), ('nfev', 'nfev'), ('fl', 'fl'), ('fm', 'fm'),
                      ('fr', 'fr')]

    def pre_func_eval(work):
        # Scale the step geometrically and compute the new downhill endpoint.
        # `work.xr0` keeps the initial endpoint (post_func_eval rebinds
        # `work.xr`, not `work.xr0`), so unlimited points are offsets from it.
        work.step *= work.factor
        x = xp.empty_like(work.xr)
        x[~work.limited] = work.xr0[~work.limited] + work.step[~work.limited]
        x[work.limited] = work.limit[work.limited] - work.step[work.limited]
        # Since the new bracket endpoint is calculated from an offset with the
        # limit, it may be the case that the new endpoint equals the old endpoint,
        # when the old endpoint is sufficiently close to the limit. We use the
        # limit itself as the new endpoint in these cases.
        x[work.limited] = xp.where(
            x[work.limited] == work.xr[work.limited],
            work.limit[work.limited],
            x[work.limited],
        )
        return x

    def post_func_eval(x, f, work):
        # Shift the trio one step downhill: old middle becomes the opposite
        # endpoint, old downhill endpoint becomes the middle.
        work.xl, work.xm, work.xr = work.xm, work.xr, x
        work.fl, work.fm, work.fr = work.fm, work.fr, f

    def check_termination(work):
        # Condition 0: Initial bracket is invalid.
        stop = (work.status == eim._EINPUTERR)

        # Condition 1: A valid bracket has been found.
        i = (
            (work.fl >= work.fm) & (work.fr > work.fm)
            | (work.fl > work.fm) & (work.fr >= work.fm)
        ) & ~stop
        work.status[i] = eim._ECONVERGED
        stop[i] = True

        # Condition 2: Moving end of bracket reaches limit.
        i = (work.xr == work.limit) & ~stop
        work.status[i] = _ELIMITS
        stop[i] = True

        # Condition 3: non-finite value encountered
        i = ~(xp.isfinite(work.xr) & xp.isfinite(work.fr)) & ~stop
        work.status[i] = eim._EVALUEERR
        stop[i] = True

        return stop

    def post_termination_check(work):
        pass

    def customize_result(res, shape):
        # Reorder entries of xl and xr if they were swapped due to f(xl0) < f(xr0).
        comp = res['xl'] > res['xr']
        res['xl'][comp], res['xr'][comp] = res['xr'][comp], res['xl'][comp]
        res['fl'][comp], res['fr'][comp] = res['fr'][comp], res['fl'][comp]
        return shape

    return eim._loop(work, callback, shape,
                     maxiter, func, args, dtype,
                     pre_func_eval, post_func_eval,
                     check_termination, post_termination_check,
                     customize_result, res_work_pairs, xp)