test_bracket.py 35 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888
  1. import pytest
  2. import numpy as np
  3. from scipy.optimize._bracket import _ELIMITS
  4. from scipy.optimize.elementwise import bracket_root, bracket_minimum
  5. import scipy._lib._elementwise_iterative_method as eim
  6. from scipy import stats
  7. from scipy._lib._array_api_no_0d import (xp_assert_close, xp_assert_equal,
  8. xp_assert_less)
  9. from scipy._lib._array_api import xp_ravel, make_xp_test_case
  10. # These tests were originally written for the private `optimize._bracket`
  11. # interfaces, but now we want the tests to check the behavior of the public
  12. # `optimize.elementwise` interfaces. Therefore, rather than importing
  13. # `_bracket_root`/`_bracket_minimum` from `_bracket.py`, we import
  14. # `bracket_root`/`bracket_minimum` from `optimize.elementwise` and wrap those
  15. # functions to conform to the private interface. This may look a little strange,
  16. # since it effectively just inverts the interface transformation done within the
  17. # `bracket_root`/`bracket_minimum` functions, but it allows us to run the original,
  18. # unmodified tests on the public interfaces, simplifying the PR that adds
  19. # the public interfaces. We'll refactor this when we want to @parametrize the
  20. # tests over multiple `method`s.
  21. def _bracket_root(*args, **kwargs):
  22. res = bracket_root(*args, **kwargs)
  23. res.xl, res.xr = res.bracket
  24. res.fl, res.fr = res.f_bracket
  25. del res.bracket
  26. del res.f_bracket
  27. return res
  28. def _bracket_minimum(*args, **kwargs):
  29. res = bracket_minimum(*args, **kwargs)
  30. res.xl, res.xm, res.xr = res.bracket
  31. res.fl, res.fm, res.fr = res.f_bracket
  32. del res.bracket
  33. del res.f_bracket
  34. return res
  35. @make_xp_test_case(bracket_root)
  36. class TestBracketRoot:
  37. @pytest.mark.parametrize("seed", (615655101, 3141866013, 238075752))
  38. @pytest.mark.parametrize("use_xmin", (False, True))
  39. @pytest.mark.parametrize("other_side", (False, True))
  40. @pytest.mark.parametrize("fix_one_side", (False, True))
  41. def test_nfev_expected(self, seed, use_xmin, other_side, fix_one_side, xp):
  42. # Property-based test to confirm that _bracket_root is behaving as
  43. # expected. The basic case is when root < a < b.
  44. # The number of times bracket expands (per side) can be found by
  45. # setting the expression for the left endpoint of the bracket to the
  46. # root of f (x=0), solving for i, and rounding up. The corresponding
  47. # lower and upper ends of the bracket are found by plugging this back
  48. # into the expression for the ends of the bracket.
  49. # `other_side=True` is the case that a < b < root
  50. # Special cases like a < root < b are tested separately
  51. rng = np.random.default_rng(seed)
  52. xl0, d, factor = xp.asarray(rng.random(size=3) * [1e5, 10, 5])
  53. factor = 1 + factor # factor must be greater than 1
  54. xr0 = xl0 + d # xr0 must be greater than a in basic case
  55. def f(x):
  56. f.count += 1
  57. return x # root is 0
  58. if use_xmin:
  59. xmin = xp.asarray(-rng.random())
  60. n = xp.ceil(xp.log(-(xl0 - xmin) / xmin) / xp.log(factor))
  61. l, u = xmin + (xl0 - xmin)*factor**-n, xmin + (xl0 - xmin)*factor**-(n - 1)
  62. kwargs = dict(xl0=xl0, xr0=xr0, factor=factor, xmin=xmin)
  63. else:
  64. n = xp.ceil(xp.log(xr0/d) / xp.log(factor))
  65. l, u = xr0 - d*factor**n, xr0 - d*factor**(n-1)
  66. kwargs = dict(xl0=xl0, xr0=xr0, factor=factor)
  67. if other_side:
  68. kwargs['xl0'], kwargs['xr0'] = -kwargs['xr0'], -kwargs['xl0']
  69. l, u = -u, -l
  70. if 'xmin' in kwargs:
  71. kwargs['xmax'] = -kwargs.pop('xmin')
  72. if fix_one_side:
  73. if other_side:
  74. kwargs['xmin'] = -xr0
  75. else:
  76. kwargs['xmax'] = xr0
  77. f.count = 0
  78. res = _bracket_root(f, **kwargs)
  79. # Compare reported number of function evaluations `nfev` against
  80. # reported `nit`, actual function call count `f.count`, and theoretical
  81. # number of expansions `n`.
  82. # When both sides are free, these get multiplied by 2 because function
  83. # is evaluated on the left and the right each iteration.
  84. # When one side is fixed, however, we add one: on the right side, the
  85. # function gets evaluated once at b.
  86. # Add 1 to `n` and `res.nit` because function evaluations occur at
  87. # iterations *0*, 1, ..., `n`. Subtract 1 from `f.count` because
  88. # function is called separately for left and right in iteration 0.
  89. if not fix_one_side:
  90. assert res.nfev == 2*(res.nit+1) == 2*(f.count-1) == 2*(n + 1)
  91. else:
  92. assert res.nfev == (res.nit+1)+1 == (f.count-1)+1 == (n+1)+1
  93. # Compare reported bracket to theoretical bracket and reported function
  94. # values to function evaluated at bracket.
  95. bracket = xp.asarray([res.xl, res.xr])
  96. xp_assert_close(bracket, xp.asarray([l, u]))
  97. f_bracket = xp.asarray([res.fl, res.fr])
  98. xp_assert_close(f_bracket, f(bracket))
  99. # Check that bracket is valid and that status and success are correct
  100. assert res.xr > res.xl
  101. signs = xp.sign(f_bracket)
  102. assert signs[0] == -signs[1]
  103. assert res.status == 0
  104. assert res.success
  105. def f(self, q, p):
  106. return stats._stats_py._SimpleNormal().cdf(q) - p
  107. @pytest.mark.parametrize('p', [0.6, np.linspace(0.05, 0.95, 10)])
  108. @pytest.mark.parametrize('xmin', [-5, None])
  109. @pytest.mark.parametrize('xmax', [5, None])
  110. @pytest.mark.parametrize('factor', [1.2, 2])
  111. def test_basic(self, p, xmin, xmax, factor, xp):
  112. # Test basic functionality to bracket root (distribution PPF)
  113. res = _bracket_root(self.f, xp.asarray(-0.01), 0.01, xmin=xmin, xmax=xmax,
  114. factor=factor, args=(xp.asarray(p),))
  115. xp_assert_equal(-xp.sign(res.fl), xp.sign(res.fr))
  116. @pytest.mark.parametrize('shape', [tuple(), (12,), (3, 4), (3, 2, 2)])
  117. def test_vectorization(self, shape, xp):
  118. # Test for correct functionality, output shapes, and dtypes for various
  119. # input shapes.
  120. p = np.linspace(-0.05, 1.05, 12).reshape(shape) if shape else np.float64(0.6)
  121. args = (p,)
  122. maxiter = 10
  123. @np.vectorize
  124. def bracket_root_single(xl0, xr0, xmin, xmax, factor, p):
  125. return _bracket_root(self.f, xl0, xr0, xmin=xmin, xmax=xmax,
  126. factor=factor, args=(p,),
  127. maxiter=maxiter)
  128. def f(*args, **kwargs):
  129. f.f_evals += 1
  130. return self.f(*args, **kwargs)
  131. f.f_evals = 0
  132. rng = np.random.default_rng(2348234)
  133. xl0 = -rng.random(size=shape)
  134. xr0 = rng.random(size=shape)
  135. xmin, xmax = 1e3*xl0, 1e3*xr0
  136. if shape: # make some elements un
  137. i = rng.random(size=shape) > 0.5
  138. xmin[i], xmax[i] = -np.inf, np.inf
  139. factor = rng.random(size=shape) + 1.5
  140. refs = bracket_root_single(xl0, xr0, xmin, xmax, factor, p).ravel()
  141. xl0, xr0, xmin, xmax, factor = (xp.asarray(xl0), xp.asarray(xr0),
  142. xp.asarray(xmin), xp.asarray(xmax),
  143. xp.asarray(factor))
  144. args = tuple(map(xp.asarray, args))
  145. res = _bracket_root(f, xl0, xr0, xmin=xmin, xmax=xmax, factor=factor,
  146. args=args, maxiter=maxiter)
  147. attrs = ['xl', 'xr', 'fl', 'fr', 'success', 'nfev', 'nit']
  148. for attr in attrs:
  149. ref_attr = [xp.asarray(getattr(ref, attr)) for ref in refs]
  150. res_attr = getattr(res, attr)
  151. xp_assert_close(xp_ravel(res_attr, xp=xp), xp.stack(ref_attr))
  152. assert res_attr.shape == shape
  153. assert res.success.dtype == xp.bool
  154. if shape:
  155. assert xp.all(res.success[1:-1])
  156. assert res.status.dtype == xp.int32
  157. assert res.nfev.dtype == xp.int32
  158. assert res.nit.dtype == xp.int32
  159. assert xp.max(res.nit) == f.f_evals - 2
  160. xp_assert_less(res.xl, res.xr)
  161. xp_assert_close(res.fl, xp.asarray(self.f(res.xl, *args)))
  162. xp_assert_close(res.fr, xp.asarray(self.f(res.xr, *args)))
  163. def test_flags(self, xp):
  164. # Test cases that should produce different status flags; show that all
  165. # can be produced simultaneously.
  166. def f(xs, js):
  167. funcs = [lambda x: x - 1.5,
  168. lambda x: x - 1000,
  169. lambda x: x - 1000,
  170. lambda x: x * xp.nan,
  171. lambda x: x]
  172. return [funcs[int(j)](x) for x, j in zip(xs, js)]
  173. args = (xp.arange(5, dtype=xp.int64),)
  174. res = _bracket_root(f,
  175. xl0=xp.asarray([-1., -1., -1., -1., 4.]),
  176. xr0=xp.asarray([1, 1, 1, 1, -4]),
  177. xmin=xp.asarray([-xp.inf, -1, -xp.inf, -xp.inf, 6]),
  178. xmax=xp.asarray([xp.inf, 1, xp.inf, xp.inf, 2]),
  179. args=args, maxiter=3)
  180. ref_flags = xp.asarray([eim._ECONVERGED,
  181. _ELIMITS,
  182. eim._ECONVERR,
  183. eim._EVALUEERR,
  184. eim._EINPUTERR],
  185. dtype=xp.int32)
  186. xp_assert_equal(res.status, ref_flags)
  187. @pytest.mark.parametrize("root", (0.622, [0.622, 0.623]))
  188. @pytest.mark.parametrize('xmin', [-5, None])
  189. @pytest.mark.parametrize('xmax', [5, None])
  190. @pytest.mark.parametrize("dtype", ("float16", "float32", "float64"))
  191. def test_dtype(self, root, xmin, xmax, dtype, xp):
  192. # Test that dtypes are preserved
  193. dtype = getattr(xp, dtype)
  194. xmin = xmin if xmin is None else xp.asarray(xmin, dtype=dtype)
  195. xmax = xmax if xmax is None else xp.asarray(xmax, dtype=dtype)
  196. root = xp.asarray(root, dtype=dtype)
  197. def f(x, root):
  198. return xp.astype((x - root) ** 3, dtype)
  199. bracket = xp.asarray([-0.01, 0.01], dtype=dtype)
  200. res = _bracket_root(f, *bracket, xmin=xmin, xmax=xmax, args=(root,))
  201. assert xp.all(res.success)
  202. assert res.xl.dtype == res.xr.dtype == dtype
  203. assert res.fl.dtype == res.fr.dtype == dtype
  204. def test_input_validation(self, xp):
  205. # Test input validation for appropriate error messages
  206. message = '`func` must be callable.'
  207. with pytest.raises(ValueError, match=message):
  208. _bracket_root(None, -4, 4)
  209. message = '...must be numeric and real.'
  210. with pytest.raises(ValueError, match=message):
  211. _bracket_root(lambda x: x, -4+1j, 4)
  212. with pytest.raises(ValueError, match=message):
  213. _bracket_root(lambda x: x, -4, 4+1j)
  214. with pytest.raises(ValueError, match=message):
  215. _bracket_root(lambda x: x, -4, 4, xmin=4+1j)
  216. with pytest.raises(ValueError, match=message):
  217. _bracket_root(lambda x: x, -4, 4, xmax=4+1j)
  218. with pytest.raises(ValueError, match=message):
  219. _bracket_root(lambda x: x, -4, 4, factor=4+1j)
  220. message = "All elements of `factor` must be greater than 1."
  221. with pytest.raises(ValueError, match=message):
  222. _bracket_root(lambda x: x, -4, 4, factor=0.5)
  223. message = "broadcast"
  224. # raised by `xp.broadcast, but the traceback is readable IMO
  225. with pytest.raises(Exception, match=message):
  226. _bracket_root(lambda x: x, xp.asarray([-2, -3]), xp.asarray([3, 4, 5]))
  227. # Consider making this give a more readable error message
  228. # with pytest.raises(ValueError, match=message):
  229. # _bracket_root(lambda x: [x[0], x[1], x[1]], [-3, -3], [5, 5])
  230. message = '`maxiter` must be a non-negative integer.'
  231. with pytest.raises(ValueError, match=message):
  232. _bracket_root(lambda x: x, -4, 4, maxiter=1.5)
  233. with pytest.raises(ValueError, match=message):
  234. _bracket_root(lambda x: x, -4, 4, maxiter=-1)
  235. with pytest.raises(ValueError, match=message):
  236. _bracket_root(lambda x: x, -4, 4, maxiter="shrubbery")
  237. def test_special_cases(self, xp):
  238. # Test edge cases and other special cases
  239. # Test that integers are not passed to `f`
  240. # (otherwise this would overflow)
  241. def f(x):
  242. assert xp.isdtype(x.dtype, "real floating")
  243. return x ** 99 - 1
  244. res = _bracket_root(f, xp.asarray(-7.), xp.asarray(5.))
  245. assert res.success
  246. # Test maxiter = 0. Should do nothing to bracket.
  247. def f(x):
  248. return x - 10
  249. bracket = (xp.asarray(-3.), xp.asarray(5.))
  250. res = _bracket_root(f, *bracket, maxiter=0)
  251. assert res.xl, res.xr == bracket
  252. assert res.nit == 0
  253. assert res.nfev == 2
  254. assert res.status == -2
  255. # Test scalar `args` (not in tuple)
  256. def f(x, c):
  257. return c*x - 1
  258. res = _bracket_root(f, xp.asarray(-1.), xp.asarray(1.),
  259. args=xp.asarray(3.))
  260. assert res.success
  261. xp_assert_close(res.fl, f(res.xl, 3))
  262. # Test other edge cases
  263. def f(x):
  264. f.count += 1
  265. return x
  266. # 1. root lies within guess of bracket
  267. f.count = 0
  268. _bracket_root(f, xp.asarray(-10), xp.asarray(20))
  269. assert f.count == 2
  270. # 2. bracket endpoint hits root exactly
  271. f.count = 0
  272. res = _bracket_root(f, xp.asarray(5.), xp.asarray(10.),
  273. factor=2)
  274. assert res.nfev == 4
  275. xp_assert_close(res.xl, xp.asarray(0.), atol=1e-15)
  276. xp_assert_close(res.xr, xp.asarray(5.), atol=1e-15)
  277. # 3. bracket limit hits root exactly
  278. with np.errstate(over='ignore'):
  279. res = _bracket_root(f, xp.asarray(5.), xp.asarray(10.),
  280. xmin=0)
  281. xp_assert_close(res.xl, xp.asarray(0.), atol=1e-15)
  282. with np.errstate(over='ignore'):
  283. res = _bracket_root(f, xp.asarray(-10.), xp.asarray(-5.),
  284. xmax=0)
  285. xp_assert_close(res.xr, xp.asarray(0.), atol=1e-15)
  286. # 4. bracket not within min, max
  287. with np.errstate(over='ignore'):
  288. res = _bracket_root(f, xp.asarray(5.), xp.asarray(10.),
  289. xmin=1)
  290. assert not res.success
  291. def test_bug_fixes(self):
  292. # 1. Bug in double sided bracket search.
  293. # Happened in some cases where there are terminations on one side
  294. # after corresponding searches on other side failed due to reaching the
  295. # boundary.
  296. # https://github.com/scipy/scipy/pull/22560#discussion_r1962853839
  297. def f(x, p):
  298. return np.exp(x) - p
  299. p = np.asarray([0.29, 0.35])
  300. res = _bracket_root(f, xl0=-1, xmin=-np.inf, xmax=0, args=(p, ))
  301. # https://github.com/scipy/scipy/pull/22560/files#r1962952517
  302. def f(x, p, c):
  303. return np.exp(x*c) - p
  304. p = [0.32061201, 0.39175242, 0.40047535, 0.50527218, 0.55654373,
  305. 0.11911647, 0.37507896, 0.66554191]
  306. c = [1., -1., 1., 1., -1., 1., 1., 1.]
  307. xl0 = [-7.63108551, 3.27840947, -8.36968526, -1.78124372,
  308. 0.92201295, -2.48930123, -0.66733533, -0.44606749]
  309. xr0 = [-6.63108551, 4.27840947, -7.36968526, -0.78124372,
  310. 1.92201295, -1.48930123, 0., 0.]
  311. xmin = [-np.inf, 0., -np.inf, -np.inf, 0., -np.inf, -np.inf,
  312. -np.inf]
  313. xmax = [0., np.inf, 0., 0., np.inf, 0., 0., 0.]
  314. res = _bracket_root(f, xl0=xl0, xr0=xr0, xmin=xmin, xmax=xmax, args=(p, c))
  315. # 2. Default xl0 + 1 for xr0 exceeds xmax.
  316. # https://github.com/scipy/scipy/pull/22560#discussion_r1962947434
  317. res = _bracket_root(lambda x: x + 0.25, xl0=-0.5, xmin=-np.inf, xmax=0)
  318. assert res.success
  319. @make_xp_test_case(bracket_minimum)
  320. class TestBracketMinimum:
  321. def init_f(self):
  322. def f(x, a, b):
  323. f.count += 1
  324. return (x - a)**2 + b
  325. f.count = 0
  326. return f
  327. def assert_valid_bracket(self, result, xp):
  328. assert xp.all(
  329. (result.xl < result.xm) & (result.xm < result.xr)
  330. )
  331. assert xp.all(
  332. (result.fl >= result.fm) & (result.fr > result.fm)
  333. | (result.fl > result.fm) & (result.fr > result.fm)
  334. )
  335. def get_kwargs(
  336. self, *, xl0=None, xr0=None, factor=None, xmin=None, xmax=None, args=None
  337. ):
  338. names = ("xl0", "xr0", "xmin", "xmax", "factor", "args")
  339. return {
  340. name: val for name, val in zip(names, (xl0, xr0, xmin, xmax, factor, args))
  341. if val is not None
  342. }
  343. @pytest.mark.parametrize(
  344. "seed",
  345. (
  346. 307448016549685229886351382450158984917,
  347. 11650702770735516532954347931959000479,
  348. 113767103358505514764278732330028568336,
  349. )
  350. )
  351. @pytest.mark.parametrize("use_xmin", (False, True))
  352. @pytest.mark.parametrize("other_side", (False, True))
  353. def test_nfev_expected(self, seed, use_xmin, other_side, xp):
  354. rng = np.random.default_rng(seed)
  355. args = (xp.asarray(0.), xp.asarray(0.)) # f(x) = x^2 with minimum at 0
  356. # xl0, xm0, xr0 are chosen such that the initial bracket is to
  357. # the right of the minimum, and the bracket will expand
  358. # downhill towards zero.
  359. xl0, d1, d2, factor = xp.asarray(rng.random(size=4) * [1e5, 10, 10, 5])
  360. xm0 = xl0 + d1
  361. xr0 = xm0 + d2
  362. # Factor should be greater than one.
  363. factor += 1
  364. if use_xmin:
  365. xmin = xp.asarray(-rng.random() * 5, dtype=xp.float64)
  366. n = int(xp.ceil(xp.log(-(xl0 - xmin) / xmin) / xp.log(factor)))
  367. lower = xmin + (xl0 - xmin)*factor**-n
  368. middle = xmin + (xl0 - xmin)*factor**-(n-1)
  369. upper = xmin + (xl0 - xmin)*factor**-(n-2) if n > 1 else xm0
  370. # It may be the case the lower is below the minimum, but we still
  371. # don't have a valid bracket.
  372. if middle**2 > lower**2:
  373. n += 1
  374. lower, middle, upper = (
  375. xmin + (xl0 - xmin)*factor**-n, lower, middle
  376. )
  377. else:
  378. xmin = None
  379. n = int(xp.ceil(xp.log(xl0 / d1) / xp.log(factor)))
  380. lower = xl0 - d1*factor**n
  381. middle = xl0 - d1*factor**(n-1) if n > 1 else xl0
  382. upper = xl0 - d1*factor**(n-2) if n > 1 else xm0
  383. # It may be the case the lower is below the minimum, but we still
  384. # don't have a valid bracket.
  385. if middle**2 > lower**2:
  386. n += 1
  387. lower, middle, upper = (
  388. xl0 - d1*factor**n, lower, middle
  389. )
  390. f = self.init_f()
  391. xmax = None
  392. if other_side:
  393. xl0, xm0, xr0 = -xr0, -xm0, -xl0
  394. xmin, xmax = None, -xmin if xmin is not None else None
  395. lower, middle, upper = -upper, -middle, -lower
  396. kwargs = self.get_kwargs(
  397. xl0=xl0, xr0=xr0, xmin=xmin, xmax=xmax, factor=factor, args=args
  398. )
  399. result = _bracket_minimum(f, xp.asarray(xm0), **kwargs)
  400. # Check that `nfev` and `nit` have the correct relationship
  401. assert result.nfev == result.nit + 3
  402. # Check that `nfev` reports the correct number of function evaluations.
  403. assert result.nfev == f.count
  404. # Check that the number of iterations matches the theoretical value.
  405. assert result.nit == n
  406. # Compare reported bracket to theoretical bracket and reported function
  407. # values to function evaluated at bracket.
  408. xp_assert_close(result.xl, lower)
  409. xp_assert_close(result.xm, middle)
  410. xp_assert_close(result.xr, upper)
  411. xp_assert_close(result.fl, f(lower, *args))
  412. xp_assert_close(result.fm, f(middle, *args))
  413. xp_assert_close(result.fr, f(upper, *args))
  414. self.assert_valid_bracket(result, xp)
  415. assert result.status == 0
  416. assert result.success
  417. def test_flags(self, xp):
  418. # Test cases that should produce different status flags; show that all
  419. # can be produced simultaneously
  420. def f(xs, js):
  421. funcs = [lambda x: (x - 1.5)**2,
  422. lambda x: x,
  423. lambda x: x,
  424. lambda x: xp.asarray(xp.nan),
  425. lambda x: x**2]
  426. return [funcs[int(j)](x) for x, j in zip(xs, js)]
  427. args = (xp.arange(5, dtype=xp.int64),)
  428. xl0 = xp.asarray([-1.0, -1.0, -1.0, -1.0, 6.0])
  429. xm0 = xp.asarray([0.0, 0.0, 0.0, 0.0, 4.0])
  430. xr0 = xp.asarray([1.0, 1.0, 1.0, 1.0, 2.0])
  431. xmin = xp.asarray([-xp.inf, -1.0, -xp.inf, -xp.inf, 8.0])
  432. result = _bracket_minimum(f, xm0, xl0=xl0, xr0=xr0, xmin=xmin,
  433. args=args, maxiter=3)
  434. reference_flags = xp.asarray([eim._ECONVERGED, _ELIMITS,
  435. eim._ECONVERR, eim._EVALUEERR,
  436. eim._EINPUTERR], dtype=xp.int32)
  437. xp_assert_equal(result.status, reference_flags)
  438. @pytest.mark.parametrize("minimum", (0.622, [0.622, 0.623]))
  439. @pytest.mark.parametrize("dtype", ("float16", "float32", "float64"))
  440. @pytest.mark.parametrize("xmin", [-5, None])
  441. @pytest.mark.parametrize("xmax", [5, None])
  442. def test_dtypes(self, minimum, xmin, xmax, dtype, xp):
  443. dtype = getattr(xp, dtype)
  444. xmin = xmin if xmin is None else xp.asarray(xmin, dtype=dtype)
  445. xmax = xmax if xmax is None else xp.asarray(xmax, dtype=dtype)
  446. minimum = xp.asarray(minimum, dtype=dtype)
  447. def f(x, minimum):
  448. return xp.astype((x - minimum)**2, dtype)
  449. xl0, xm0, xr0 = [-0.01, 0.0, 0.01]
  450. result = _bracket_minimum(
  451. f, xp.asarray(xm0, dtype=dtype), xl0=xp.asarray(xl0, dtype=dtype),
  452. xr0=xp.asarray(xr0, dtype=dtype), xmin=xmin, xmax=xmax, args=(minimum, )
  453. )
  454. assert xp.all(result.success)
  455. assert result.xl.dtype == result.xm.dtype == result.xr.dtype == dtype
  456. assert result.fl.dtype == result.fm.dtype == result.fr.dtype == dtype
  457. @pytest.mark.skip_xp_backends(np_only=True, reason="str/object arrays")
  458. def test_input_validation(self, xp):
  459. # Test input validation for appropriate error messages
  460. message = '`func` must be callable.'
  461. with pytest.raises(ValueError, match=message):
  462. _bracket_minimum(None, -4, xl0=4)
  463. message = '...must be numeric and real.'
  464. with pytest.raises(ValueError, match=message):
  465. _bracket_minimum(lambda x: x**2, xp.asarray(4+1j))
  466. with pytest.raises(ValueError, match=message):
  467. _bracket_minimum(lambda x: x**2, xp.asarray(-4), xl0=4+1j)
  468. with pytest.raises(ValueError, match=message):
  469. _bracket_minimum(lambda x: x**2, xp.asarray(-4), xr0=4+1j)
  470. with pytest.raises(ValueError, match=message):
  471. _bracket_minimum(lambda x: x**2, xp.asarray(-4), xmin=4+1j)
  472. with pytest.raises(ValueError, match=message):
  473. _bracket_minimum(lambda x: x**2, xp.asarray(-4), xmax=4+1j)
  474. with pytest.raises(ValueError, match=message):
  475. _bracket_minimum(lambda x: x**2, xp.asarray(-4), factor=4+1j)
  476. message = "All elements of `factor` must be greater than 1."
  477. with pytest.raises(ValueError, match=message):
  478. _bracket_minimum(lambda x: x, xp.asarray(-4), factor=0.5)
  479. message = "Array shapes are incompatible for broadcasting."
  480. with pytest.raises(ValueError, match=message):
  481. _bracket_minimum(lambda x: x**2, xp.asarray([-2, -3]), xl0=[-3, -4, -5])
  482. message = '`maxiter` must be a non-negative integer.'
  483. with pytest.raises(ValueError, match=message):
  484. _bracket_minimum(lambda x: x**2, xp.asarray(-4), xr0=4, maxiter=1.5)
  485. with pytest.raises(ValueError, match=message):
  486. _bracket_minimum(lambda x: x**2, xp.asarray(-4), xr0=4, maxiter=-1)
  487. with pytest.raises(ValueError, match=message):
  488. _bracket_minimum(lambda x: x**2, xp.asarray(-4), xr0=4, maxiter="ekki")
  489. @pytest.mark.parametrize("xl0", [0.0, None])
  490. @pytest.mark.parametrize("xm0", (0.05, 0.1, 0.15))
  491. @pytest.mark.parametrize("xr0", (0.2, 0.4, 0.6, None))
  492. # Minimum is ``a`` for each tuple ``(a, b)`` below. Tests cases where minimum
  493. # is within, or at varying distances to the left or right of the initial
  494. # bracket.
  495. @pytest.mark.parametrize(
  496. "args",
  497. (
  498. (1.2, 0), (-0.5, 0), (0.1, 0), (0.2, 0), (3.6, 0), (21.4, 0),
  499. (121.6, 0), (5764.1, 0), (-6.4, 0), (-12.9, 0), (-146.2, 0)
  500. )
  501. )
  502. def test_scalar_no_limits(self, xl0, xm0, xr0, args, xp):
  503. f = self.init_f()
  504. kwargs = self.get_kwargs(xl0=xl0, xr0=xr0, args=tuple(map(xp.asarray, args)))
  505. result = _bracket_minimum(f, xp.asarray(xm0, dtype=xp.float64), **kwargs)
  506. self.assert_valid_bracket(result, xp)
  507. assert result.status == 0
  508. assert result.success
  509. assert result.nfev == f.count
  510. @pytest.mark.parametrize(
  511. # xmin is set at 0.0 in all cases.
  512. "xl0,xm0,xr0,xmin",
  513. (
  514. # Initial bracket at varying distances from the xmin.
  515. (0.5, 0.75, 1.0, 0.0),
  516. (1.0, 2.5, 4.0, 0.0),
  517. (2.0, 4.0, 6.0, 0.0),
  518. (12.0, 16.0, 20.0, 0.0),
  519. # Test default initial left endpoint selection. It should not
  520. # be below xmin.
  521. (None, 0.75, 1.0, 0.0),
  522. (None, 2.5, 4.0, 0.0),
  523. (None, 4.0, 6.0, 0.0),
  524. (None, 16.0, 20.0, 0.0),
  525. )
  526. )
  527. @pytest.mark.parametrize(
  528. "args", (
  529. (0.0, 0.0), # Minimum is directly at xmin.
  530. (1e-300, 0.0), # Minimum is extremely close to xmin.
  531. (1e-20, 0.0), # Minimum is very close to xmin.
  532. # Minimum at varying distances from xmin.
  533. (0.1, 0.0),
  534. (0.2, 0.0),
  535. (0.4, 0.0)
  536. )
  537. )
  538. def test_scalar_with_limit_left(self, xl0, xm0, xr0, xmin, args, xp):
  539. f = self.init_f()
  540. kwargs = self.get_kwargs(xl0=xl0, xr0=xr0, xmin=xmin,
  541. args=tuple(map(xp.asarray, args)))
  542. result = _bracket_minimum(f, xp.asarray(xm0), **kwargs)
  543. self.assert_valid_bracket(result, xp)
  544. assert result.status == 0
  545. assert result.success
  546. assert result.nfev == f.count
  547. @pytest.mark.parametrize(
  548. #xmax is set to 1.0 in all cases.
  549. "xl0,xm0,xr0,xmax",
  550. (
  551. # Bracket at varying distances from xmax.
  552. (0.2, 0.3, 0.4, 1.0),
  553. (0.05, 0.075, 0.1, 1.0),
  554. (-0.2, -0.1, 0.0, 1.0),
  555. (-21.2, -17.7, -14.2, 1.0),
  556. # Test default right endpoint selection. It should not exceed xmax.
  557. (0.2, 0.3, None, 1.0),
  558. (0.05, 0.075, None, 1.0),
  559. (-0.2, -0.1, None, 1.0),
  560. (-21.2, -17.7, None, 1.0),
  561. )
  562. )
  563. @pytest.mark.parametrize(
  564. "args", (
  565. (0.9999999999999999, 0.0), # Minimum very close to xmax.
  566. # Minimum at varying distances from xmax.
  567. (0.9, 0.0),
  568. (0.7, 0.0),
  569. (0.5, 0.0)
  570. )
  571. )
  572. def test_scalar_with_limit_right(self, xl0, xm0, xr0, xmax, args, xp):
  573. f = self.init_f()
  574. args = tuple(xp.asarray(arg, dtype=xp.float64) for arg in args)
  575. kwargs = self.get_kwargs(xl0=xl0, xr0=xr0, xmax=xmax, args=args)
  576. result = _bracket_minimum(f, xp.asarray(xm0, dtype=xp.float64), **kwargs)
  577. self.assert_valid_bracket(result, xp)
  578. assert result.status == 0
  579. assert result.success
  580. assert result.nfev == f.count
  581. @pytest.mark.parametrize(
  582. "xl0,xm0,xr0,xmin,xmax,args",
  583. (
  584. ( # Case 1:
  585. # Initial bracket.
  586. 0.2,
  587. 0.3,
  588. 0.4,
  589. # Function slopes down to the right from the bracket to a minimum
  590. # at 1.0. xmax is also at 1.0
  591. None,
  592. 1.0,
  593. (1.0, 0.0)
  594. ),
  595. ( # Case 2:
  596. # Initial bracket.
  597. 1.4,
  598. 1.95,
  599. 2.5,
  600. # Function slopes down to the left from the bracket to a minimum at
  601. # 0.3 with xmin set to 0.3.
  602. 0.3,
  603. None,
  604. (0.3, 0.0)
  605. ),
  606. (
  607. # Case 3:
  608. # Initial bracket.
  609. 2.6,
  610. 3.25,
  611. 3.9,
  612. # Function slopes down and to the right to a minimum at 99.4 with xmax
  613. # at 99.4. Tests case where minimum is at xmax relatively further from
  614. # the bracket.
  615. None,
  616. 99.4,
  617. (99.4, 0)
  618. ),
  619. (
  620. # Case 4:
  621. # Initial bracket.
  622. 4,
  623. 4.5,
  624. 5,
  625. # Function slopes down and to the left away from the bracket with a
  626. # minimum at -26.3 with xmin set to -26.3. Tests case where minimum is
  627. # at xmin relatively far from the bracket.
  628. -26.3,
  629. None,
  630. (-26.3, 0)
  631. ),
  632. (
  633. # Case 5:
  634. # Similar to Case 1 above, but tests default values of xl0 and xr0.
  635. None,
  636. 0.3,
  637. None,
  638. None,
  639. 1.0,
  640. (1.0, 0.0)
  641. ),
  642. ( # Case 6:
  643. # Similar to Case 2 above, but tests default values of xl0 and xr0.
  644. None,
  645. 1.95,
  646. None,
  647. 0.3,
  648. None,
  649. (0.3, 0.0)
  650. ),
  651. (
  652. # Case 7:
  653. # Similar to Case 3 above, but tests default values of xl0 and xr0.
  654. None,
  655. 3.25,
  656. None,
  657. None,
  658. 99.4,
  659. (99.4, 0)
  660. ),
  661. (
  662. # Case 8:
  663. # Similar to Case 4 above, but tests default values of xl0 and xr0.
  664. None,
  665. 4.5,
  666. None,
  667. -26.3,
  668. None,
  669. (-26.3, 0)
  670. ),
  671. )
  672. )
  673. def test_minimum_at_boundary_point(self, xl0, xm0, xr0, xmin, xmax, args, xp):
  674. f = self.init_f()
  675. kwargs = self.get_kwargs(xr0=xr0, xmin=xmin, xmax=xmax,
  676. args=tuple(map(xp.asarray, args)))
  677. result = _bracket_minimum(f, xp.asarray(xm0), **kwargs)
  678. assert result.status == -1
  679. assert args[0] in (result.xl, result.xr)
  680. assert result.nfev == f.count
  681. @pytest.mark.parametrize('shape', [tuple(), (12, ), (3, 4), (3, 2, 2)])
  682. def test_vectorization(self, shape, xp):
  683. # Test for correct functionality, output shapes, and dtypes for
  684. # various input shapes.
  685. a = np.linspace(-0.05, 1.05, 12).reshape(shape) if shape else 0.6
  686. args = (a, 0.)
  687. maxiter = 10
  688. @np.vectorize
  689. def bracket_minimum_single(xm0, xl0, xr0, xmin, xmax, factor, a):
  690. return _bracket_minimum(self.init_f(), xm0, xl0=xl0, xr0=xr0, xmin=xmin,
  691. xmax=xmax, factor=factor, maxiter=maxiter,
  692. args=(a, 0.0))
  693. f = self.init_f()
  694. rng = np.random.default_rng(2348234)
  695. xl0 = -rng.random(size=shape)
  696. xr0 = rng.random(size=shape)
  697. xm0 = xl0 + rng.random(size=shape) * (xr0 - xl0)
  698. xmin, xmax = 1e3*xl0, 1e3*xr0
  699. if shape: # make some elements un
  700. i = rng.random(size=shape) > 0.5
  701. xmin[i], xmax[i] = -np.inf, np.inf
  702. factor = rng.random(size=shape) + 1.5
  703. refs = bracket_minimum_single(xm0, xl0, xr0, xmin, xmax, factor, a).ravel()
  704. args = tuple(xp.asarray(arg, dtype=xp.float64) for arg in args)
  705. res = _bracket_minimum(f, xp.asarray(xm0), xl0=xp.asarray(xl0),
  706. xr0=xp.asarray(xr0), xmin=xp.asarray(xmin),
  707. xmax=xp.asarray(xmax), factor=xp.asarray(factor),
  708. args=args, maxiter=maxiter)
  709. attrs = ['xl', 'xm', 'xr', 'fl', 'fm', 'fr', 'success', 'nfev', 'nit']
  710. for attr in attrs:
  711. ref_attr = [xp.asarray(getattr(ref, attr)) for ref in refs]
  712. res_attr = getattr(res, attr)
  713. xp_assert_close(xp_ravel(res_attr, xp=xp), xp.stack(ref_attr))
  714. assert res_attr.shape == shape
  715. assert res.success.dtype == xp.bool
  716. if shape:
  717. assert xp.all(res.success[1:-1])
  718. assert res.status.dtype == xp.int32
  719. assert res.nfev.dtype == xp.int32
  720. assert res.nit.dtype == xp.int32
  721. assert xp.max(res.nit) == f.count - 3
  722. self.assert_valid_bracket(res, xp)
  723. xp_assert_close(res.fl, f(res.xl, *args))
  724. xp_assert_close(res.fm, f(res.xm, *args))
  725. xp_assert_close(res.fr, f(res.xr, *args))
  726. def test_special_cases(self, xp):
  727. # Test edge cases and other special cases.
  728. # Test that integers are not passed to `f`
  729. # (otherwise this would overflow)
  730. def f(x):
  731. assert xp.isdtype(x.dtype, "numeric")
  732. return x ** 98 - 1
  733. result = _bracket_minimum(f, xp.asarray(-7., dtype=xp.float64), xr0=5)
  734. assert result.success
  735. # Test maxiter = 0. Should do nothing to bracket.
  736. def f(x):
  737. return x**2 - 10
  738. xl0, xm0, xr0 = xp.asarray(-3.), xp.asarray(-1.), xp.asarray(2.)
  739. result = _bracket_minimum(f, xm0, xl0=xl0, xr0=xr0, maxiter=0)
  740. xp_assert_equal(result.xl, xl0)
  741. xp_assert_equal(result.xm, xm0)
  742. xp_assert_equal(result.xr, xr0)
  743. # Test scalar `args` (not in tuple)
  744. def f(x, c):
  745. return c*x**2 - 1
  746. result = _bracket_minimum(f, xp.asarray(-1.), args=xp.asarray(3.))
  747. assert result.success
  748. xp_assert_close(result.fl, f(result.xl, 3))
  749. # Initial bracket is valid.
  750. f = self.init_f()
  751. xl0, xm0, xr0 = xp.asarray(-1.0), xp.asarray(-0.2), xp.asarray(1.0)
  752. args = (xp.asarray(0.), xp.asarray(0.))
  753. result = _bracket_minimum(f, xm0, xl0=xl0, xr0=xr0, args=args)
  754. assert f.count == 3
  755. xp_assert_equal(result.xl, xl0)
  756. xp_assert_equal(result.xm , xm0)
  757. xp_assert_equal(result.xr, xr0)
  758. xp_assert_equal(result.fl, f(xl0, *args))
  759. xp_assert_equal(result.fm, f(xm0, *args))
  760. xp_assert_equal(result.fr, f(xr0, *args))
  761. def test_gh_20562_left(self, xp):
  762. # Regression test for https://github.com/scipy/scipy/issues/20562
  763. # minimum of f in [xmin, xmax] is at xmin.
  764. xmin, xmax = xp.asarray(0.21933608), xp.asarray(1.39713606)
  765. def f(x):
  766. log_a, log_b = xp.log(xmin), xp.log(xmax)
  767. return -((log_b - log_a)*x)**-1
  768. result = _bracket_minimum(f, xp.asarray(0.5535723499480897), xmin=xmin,
  769. xmax=xmax)
  770. xp_assert_close(result.xl, xmin)
  771. def test_gh_20562_right(self, xp):
  772. # Regression test for https://github.com/scipy/scipy/issues/20562
  773. # minimum of f in [xmin, xmax] is at xmax.
  774. xmin, xmax = xp.asarray(-1.39713606), xp.asarray(-0.21933608)
  775. def f(x):
  776. log_a, log_b = xp.log(-xmax), xp.log(-xmin)
  777. return ((log_b - log_a)*x)**-1
  778. result = _bracket_minimum(f, xp.asarray(-0.5535723499480897),
  779. xmin=xmin, xmax=xmax)
  780. xp_assert_close(result.xr, xmax)