test_chandrupatla.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976
  1. import math
  2. import pytest
  3. import numpy as np
  4. from copy import deepcopy
  5. from scipy import stats, special
  6. import scipy._lib._elementwise_iterative_method as eim
  7. import scipy._lib.array_api_extra as xpx
  8. from scipy._lib._array_api import (array_namespace, is_cupy, is_numpy, xp_ravel,
  9. xp_size, make_xp_test_case)
  10. from scipy._lib._array_api_no_0d import (xp_assert_close, xp_assert_equal,
  11. xp_assert_less)
  12. from scipy.optimize.elementwise import find_minimum, find_root
  13. from scipy.optimize._tstutils import _CHANDRUPATLA_TESTS
  14. from itertools import permutations
  15. def _vectorize(xp):
  16. # xp-compatible version of np.vectorize
  17. # assumes arguments are all arrays of the same shape
  18. def decorator(f):
  19. def wrapped(*arg_arrays):
  20. shape = arg_arrays[0].shape
  21. arg_arrays = [xp_ravel(arg_array, xp=xp) for arg_array in arg_arrays]
  22. res = []
  23. for i in range(math.prod(shape)):
  24. arg_scalars = [arg_array[i] for arg_array in arg_arrays]
  25. res.append(f(*arg_scalars))
  26. return res
  27. return wrapped
  28. return decorator
  29. # These tests were originally written for the private `optimize._chandrupatla`
  30. # interfaces, but now we want the tests to check the behavior of the public
  31. # `optimize.elementwise` interfaces. Therefore, rather than importing
  32. # `_chandrupatla`/`_chandrupatla_minimize` from `_chandrupatla.py`, we import
  33. # `find_root`/`find_minimum` from `optimize.elementwise` and wrap those
  34. # functions to conform to the private interface. This may look a little strange,
  35. # since it effectively just inverts the interface transformation done within the
  36. # `find_root`/`find_minimum` functions, but it allows us to run the original,
  37. # unmodified tests on the public interfaces, simplifying the PR that adds
  38. # the public interfaces. We'll refactor this when we want to @parametrize the
  39. # tests over multiple `method`s.
  40. def _wrap_chandrupatla(func):
  41. def _chandrupatla_wrapper(f, *bracket, **kwargs):
  42. # avoid passing arguments to `find_minimum` to this function
  43. tol_keys = {'xatol', 'xrtol', 'fatol', 'frtol'}
  44. tolerances = {key: kwargs.pop(key) for key in tol_keys if key in kwargs}
  45. _callback = kwargs.pop('callback', None)
  46. if callable(_callback):
  47. def callback(res):
  48. if func == find_root:
  49. res.xl, res.xr = res.bracket
  50. res.fl, res.fr = res.f_bracket
  51. else:
  52. res.xl, res.xm, res.xr = res.bracket
  53. res.fl, res.fm, res.fr = res.f_bracket
  54. res.fun = res.f_x
  55. del res.bracket
  56. del res.f_bracket
  57. del res.f_x
  58. return _callback(res)
  59. else:
  60. callback = _callback
  61. res = func(f, bracket, tolerances=tolerances, callback=callback, **kwargs)
  62. if func == find_root:
  63. res.xl, res.xr = res.bracket
  64. res.fl, res.fr = res.f_bracket
  65. else:
  66. res.xl, res.xm, res.xr = res.bracket
  67. res.fl, res.fm, res.fr = res.f_bracket
  68. res.fun = res.f_x
  69. del res.bracket
  70. del res.f_bracket
  71. del res.f_x
  72. return res
  73. return _chandrupatla_wrapper
  74. _chandrupatla_minimize = _wrap_chandrupatla(find_minimum)
  75. def f1(x):
  76. return 100*(1 - x**3.)**2 + (1-x**2.) + 2*(1-x)**2.
  77. def f2(x):
  78. return 5 + (x - 2.)**6
  79. def f3(x):
  80. xp = array_namespace(x)
  81. return xp.exp(x) - 5*x
  82. def f4(x):
  83. return x**5. - 5*x**3. - 20.*x + 5.
  84. def f5(x):
  85. return 8*x**3 - 2*x**2 - 7*x + 3
  86. def _bracket_minimum(func, x1, x2):
  87. phi = 1.61803398875
  88. maxiter = 100
  89. f1 = func(x1)
  90. f2 = func(x2)
  91. step = x2 - x1
  92. x1, x2, f1, f2, step = ((x2, x1, f2, f1, -step) if f2 > f1
  93. else (x1, x2, f1, f2, step))
  94. for i in range(maxiter):
  95. step *= phi
  96. x3 = x2 + step
  97. f3 = func(x3)
  98. if f3 < f2:
  99. x1, x2, f1, f2 = x2, x3, f2, f3
  100. else:
  101. break
  102. return x1, x2, x3, f1, f2, f3
  103. cases = [
  104. (f1, -1, 11),
  105. (f1, -2, 13),
  106. (f1, -4, 13),
  107. (f1, -8, 15),
  108. (f1, -16, 16),
  109. (f1, -32, 19),
  110. (f1, -64, 20),
  111. (f1, -128, 21),
  112. (f1, -256, 21),
  113. (f1, -512, 19),
  114. (f1, -1024, 24),
  115. (f2, -1, 8),
  116. (f2, -2, 6),
  117. (f2, -4, 6),
  118. (f2, -8, 7),
  119. (f2, -16, 8),
  120. (f2, -32, 8),
  121. (f2, -64, 9),
  122. (f2, -128, 11),
  123. (f2, -256, 13),
  124. (f2, -512, 12),
  125. (f2, -1024, 13),
  126. (f3, -1, 11),
  127. (f3, -2, 11),
  128. (f3, -4, 11),
  129. (f3, -8, 10),
  130. (f3, -16, 14),
  131. (f3, -32, 12),
  132. (f3, -64, 15),
  133. (f3, -128, 18),
  134. (f3, -256, 18),
  135. (f3, -512, 19),
  136. (f3, -1024, 19),
  137. (f4, -0.05, 9),
  138. (f4, -0.10, 11),
  139. (f4, -0.15, 11),
  140. (f4, -0.20, 11),
  141. (f4, -0.25, 11),
  142. (f4, -0.30, 9),
  143. (f4, -0.35, 9),
  144. (f4, -0.40, 9),
  145. (f4, -0.45, 10),
  146. (f4, -0.50, 10),
  147. (f4, -0.55, 10),
  148. (f5, -0.05, 6),
  149. (f5, -0.10, 7),
  150. (f5, -0.15, 8),
  151. (f5, -0.20, 10),
  152. (f5, -0.25, 9),
  153. (f5, -0.30, 8),
  154. (f5, -0.35, 7),
  155. (f5, -0.40, 7),
  156. (f5, -0.45, 9),
  157. (f5, -0.50, 9),
  158. (f5, -0.55, 8)
  159. ]
  160. @make_xp_test_case(find_minimum)
  161. class TestChandrupatlaMinimize:
  162. def f(self, x, loc):
  163. xp = array_namespace(x, loc)
  164. res = -xp.exp(-1/2 * (x-loc)**2) / (2*xp.pi)**0.5
  165. return xp.asarray(res, dtype=x.dtype)[()]
  166. @pytest.mark.parametrize('dtype', ('float32', 'float64'))
  167. @pytest.mark.parametrize('loc', [0.6, np.linspace(-1.05, 1.05, 10)])
  168. def test_basic(self, loc, xp, dtype):
  169. # Find mode of normal distribution. Compare mode against location
  170. # parameter and value of pdf at mode against expected pdf.
  171. rtol = {'float32': 5e-3, 'float64': 5e-7}[dtype]
  172. dtype = getattr(xp, dtype)
  173. bracket = (xp.asarray(xi, dtype=dtype) for xi in (-5, 0, 5))
  174. loc = xp.asarray(loc, dtype=dtype)
  175. fun = xp.broadcast_to(xp.asarray(-stats.norm.pdf(0), dtype=dtype), loc.shape)
  176. res = _chandrupatla_minimize(self.f, *bracket, args=(loc,))
  177. xp_assert_close(res.x, loc, rtol=rtol)
  178. xp_assert_equal(res.fun, fun)
  179. @pytest.mark.parametrize('shape', [tuple(), (12,), (3, 4), (3, 2, 2)])
  180. def test_vectorization(self, shape, xp):
  181. # Test for correct functionality, output shapes, and dtypes for various
  182. # input shapes.
  183. loc = xp.linspace(-0.05, 1.05, 12).reshape(shape) if shape else xp.asarray(0.6)
  184. args = (loc,)
  185. bracket = xp.asarray(-5.), xp.asarray(0.), xp.asarray(5.)
  186. @_vectorize(xp)
  187. def chandrupatla_single(loc_single):
  188. return _chandrupatla_minimize(self.f, *bracket, args=(loc_single,))
  189. def f(*args, **kwargs):
  190. f.f_evals += 1
  191. return self.f(*args, **kwargs)
  192. f.f_evals = 0
  193. res = _chandrupatla_minimize(f, *bracket, args=args)
  194. refs = chandrupatla_single(loc)
  195. attrs = ['x', 'fun', 'success', 'status', 'nfev', 'nit',
  196. 'xl', 'xm', 'xr', 'fl', 'fm', 'fr']
  197. for attr in attrs:
  198. ref_attr = xp.stack([getattr(ref, attr) for ref in refs])
  199. res_attr = xp_ravel(getattr(res, attr))
  200. xp_assert_equal(res_attr, ref_attr)
  201. assert getattr(res, attr).shape == shape
  202. xp_assert_equal(res.fun, self.f(res.x, *args))
  203. xp_assert_equal(res.fl, self.f(res.xl, *args))
  204. xp_assert_equal(res.fm, self.f(res.xm, *args))
  205. xp_assert_equal(res.fr, self.f(res.xr, *args))
  206. assert xp.max(res.nfev) == f.f_evals
  207. assert xp.max(res.nit) == f.f_evals - 3
  208. assert xp.isdtype(res.success.dtype, 'bool')
  209. assert xp.isdtype(res.status.dtype, 'integral')
  210. assert xp.isdtype(res.nfev.dtype, 'integral')
  211. assert xp.isdtype(res.nit.dtype, 'integral')
  212. def test_flags(self, xp):
  213. # Test cases that should produce different status flags; show that all
  214. # can be produced simultaneously.
  215. def f(xs, js):
  216. funcs = [lambda x: (x - 2.5) ** 2,
  217. lambda x: x - 10,
  218. lambda x: (x - 2.5) ** 4,
  219. lambda x: xp.full_like(x, xp.asarray(xp.nan))]
  220. res = []
  221. for i in range(xp_size(js)):
  222. x = xs[i, ...]
  223. j = int(xp_ravel(js)[i])
  224. res.append(funcs[j](x))
  225. return xp.stack(res)
  226. args = (xp.arange(4, dtype=xp.int64),)
  227. bracket = (xp.asarray([0]*4, dtype=xp.float64),
  228. xp.asarray([2]*4, dtype=xp.float64),
  229. xp.asarray([np.pi]*4, dtype=xp.float64))
  230. res = _chandrupatla_minimize(f, *bracket, args=args, maxiter=10)
  231. ref_flags = xp.asarray([eim._ECONVERGED, eim._ESIGNERR, eim._ECONVERR,
  232. eim._EVALUEERR], dtype=xp.int32)
  233. xp_assert_equal(res.status, ref_flags)
  234. def test_convergence(self, xp):
  235. # Test that the convergence tolerances behave as expected
  236. rng = np.random.default_rng(2585255913088665241)
  237. p = xp.asarray(rng.random(size=3))
  238. bracket = (xp.asarray(-5, dtype=xp.float64), xp.asarray(0), xp.asarray(5))
  239. args = (p,)
  240. kwargs0 = dict(args=args, xatol=0, xrtol=0, fatol=0, frtol=0)
  241. kwargs = kwargs0.copy()
  242. kwargs['xatol'] = 1e-3
  243. res1 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
  244. j1 = xp.abs(res1.xr - res1.xl)
  245. tol = xp.asarray(4*kwargs['xatol'], dtype=p.dtype)
  246. xp_assert_less(j1, xp.full((3,), tol, dtype=p.dtype))
  247. kwargs['xatol'] = 1e-6
  248. res2 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
  249. j2 = xp.abs(res2.xr - res2.xl)
  250. tol = xp.asarray(4*kwargs['xatol'], dtype=p.dtype)
  251. xp_assert_less(j2, xp.full((3,), tol, dtype=p.dtype))
  252. xp_assert_less(j2, j1)
  253. kwargs = kwargs0.copy()
  254. kwargs['xrtol'] = 1e-3
  255. res1 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
  256. j1 = xp.abs(res1.xr - res1.xl)
  257. tol = xp.asarray(4*kwargs['xrtol']*xp.abs(res1.x), dtype=p.dtype)
  258. xp_assert_less(j1, tol)
  259. kwargs['xrtol'] = 1e-6
  260. res2 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
  261. j2 = xp.abs(res2.xr - res2.xl)
  262. tol = xp.asarray(4*kwargs['xrtol']*xp.abs(res2.x), dtype=p.dtype)
  263. xp_assert_less(j2, tol)
  264. xp_assert_less(j2, j1)
  265. kwargs = kwargs0.copy()
  266. kwargs['fatol'] = 1e-3
  267. res1 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
  268. h1 = xp.abs(res1.fl - 2 * res1.fm + res1.fr)
  269. tol = xp.asarray(2*kwargs['fatol'], dtype=p.dtype)
  270. xp_assert_less(h1, xp.full((3,), tol, dtype=p.dtype))
  271. kwargs['fatol'] = 1e-6
  272. res2 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
  273. h2 = xp.abs(res2.fl - 2 * res2.fm + res2.fr)
  274. tol = xp.asarray(2*kwargs['fatol'], dtype=p.dtype)
  275. xp_assert_less(h2, xp.full((3,), tol, dtype=p.dtype))
  276. xp_assert_less(h2, h1)
  277. kwargs = kwargs0.copy()
  278. kwargs['frtol'] = 1e-3
  279. res1 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
  280. h1 = xp.abs(res1.fl - 2 * res1.fm + res1.fr)
  281. tol = xp.asarray(2*kwargs['frtol']*xp.abs(res1.fun), dtype=p.dtype)
  282. xp_assert_less(h1, tol)
  283. kwargs['frtol'] = 1e-6
  284. res2 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
  285. h2 = xp.abs(res2.fl - 2 * res2.fm + res2.fr)
  286. tol = xp.asarray(2*kwargs['frtol']*abs(res2.fun), dtype=p.dtype)
  287. xp_assert_less(h2, tol)
  288. xp_assert_less(h2, h1)
  289. def test_maxiter_callback(self, xp):
  290. # Test behavior of `maxiter` parameter and `callback` interface
  291. loc = xp.asarray(0.612814)
  292. bracket = (xp.asarray(-5), xp.asarray(0), xp.asarray(5))
  293. maxiter = 5
  294. res = _chandrupatla_minimize(self.f, *bracket, args=(loc,),
  295. maxiter=maxiter)
  296. assert not xp.any(res.success)
  297. assert xp.all(res.nfev == maxiter+3)
  298. assert xp.all(res.nit == maxiter)
  299. def callback(res):
  300. callback.iter += 1
  301. callback.res = res
  302. assert hasattr(res, 'x')
  303. if callback.iter == 0:
  304. # callback is called once with initial bracket
  305. assert (res.xl, res.xm, res.xr) == bracket
  306. else:
  307. changed_xr = (res.xl == callback.xl) & (res.xr != callback.xr)
  308. changed_xl = (res.xl != callback.xl) & (res.xr == callback.xr)
  309. assert xp.all(changed_xr | changed_xl)
  310. callback.xl = res.xl
  311. callback.xr = res.xr
  312. assert res.status == eim._EINPROGRESS
  313. xp_assert_equal(self.f(res.xl, loc), res.fl)
  314. xp_assert_equal(self.f(res.xm, loc), res.fm)
  315. xp_assert_equal(self.f(res.xr, loc), res.fr)
  316. xp_assert_equal(self.f(res.x, loc), res.fun)
  317. if callback.iter == maxiter:
  318. raise StopIteration
  319. callback.xl = xp.nan
  320. callback.xr = xp.nan
  321. callback.iter = -1 # callback called once before first iteration
  322. callback.res = None
  323. res2 = _chandrupatla_minimize(self.f, *bracket, args=(loc,),
  324. callback=callback)
  325. # terminating with callback is identical to terminating due to maxiter
  326. # (except for `status`)
  327. for key in res.keys():
  328. if key == 'status':
  329. assert res[key] == eim._ECONVERR
  330. # assert callback.res[key] == eim._EINPROGRESS
  331. assert res2[key] == eim._ECALLBACK
  332. else:
  333. assert res2[key] == callback.res[key] == res[key]
  334. @pytest.mark.parametrize('case', cases)
  335. def test_nit_expected(self, case, xp):
  336. # Test that `_chandrupatla` implements Chandrupatla's algorithm:
  337. # in all 55 test cases, the number of iterations performed
  338. # matches the number reported in the original paper.
  339. func, x1, nit = case
  340. # Find bracket using the algorithm in the paper
  341. step = 0.2
  342. x2 = x1 + step
  343. x1, x2, x3, f1, f2, f3 = _bracket_minimum(func, x1, x2)
  344. # Use tolerances from original paper
  345. xatol = 0.0001
  346. fatol = 0.000001
  347. xrtol = 1e-16
  348. frtol = 1e-16
  349. bracket = xp.asarray(x1), xp.asarray(x2), xp.asarray(x3, dtype=xp.float64)
  350. res = _chandrupatla_minimize(func, *bracket, xatol=xatol,
  351. fatol=fatol, xrtol=xrtol, frtol=frtol)
  352. xp_assert_equal(res.nit, xp.asarray(nit, dtype=xp.int32))
  353. @pytest.mark.parametrize("loc", (0.65, [0.65, 0.7]))
  354. @pytest.mark.parametrize("dtype", ('float16', 'float32', 'float64'))
  355. def test_dtype(self, loc, dtype, xp):
  356. # Test that dtypes are preserved
  357. dtype = getattr(xp, dtype)
  358. loc = xp.asarray(loc, dtype=dtype)
  359. bracket = (xp.asarray(-3, dtype=dtype),
  360. xp.asarray(1, dtype=dtype),
  361. xp.asarray(5, dtype=dtype))
  362. def f(x, loc):
  363. assert x.dtype == dtype
  364. return xp.astype((x - loc)**2, dtype)
  365. res = _chandrupatla_minimize(f, *bracket, args=(loc,))
  366. assert res.x.dtype == dtype
  367. xp_assert_close(res.x, loc, rtol=math.sqrt(xp.finfo(dtype).eps))
  368. def test_input_validation(self, xp):
  369. # Test input validation for appropriate error messages
  370. message = '`func` must be callable.'
  371. bracket = xp.asarray(-4), xp.asarray(0), xp.asarray(4)
  372. with pytest.raises(ValueError, match=message):
  373. _chandrupatla_minimize(None, *bracket)
  374. message = 'Abscissae and function output must be real numbers.'
  375. bracket = xp.asarray(-4 + 1j), xp.asarray(0), xp.asarray(4)
  376. with pytest.raises(ValueError, match=message):
  377. _chandrupatla_minimize(lambda x: x, *bracket)
  378. message = "...be broadcast..."
  379. bracket = xp.asarray([-2, -3]), xp.asarray([0, 0]), xp.asarray([3, 4, 5])
  380. # raised by `np.broadcast, but the traceback is readable IMO
  381. with pytest.raises((ValueError, RuntimeError), match=message):
  382. _chandrupatla_minimize(lambda x: x, *bracket)
  383. message = "The shape of the array returned by `func` must be the same"
  384. bracket = xp.asarray([-3, -3]), xp.asarray([0, 0]), xp.asarray([5, 5])
  385. with pytest.raises(ValueError, match=message):
  386. _chandrupatla_minimize(lambda x: [x[0, ...], x[1, ...], x[1, ...]],
  387. *bracket)
  388. message = 'Tolerances must be non-negative scalars.'
  389. bracket = xp.asarray(-4), xp.asarray(0), xp.asarray(4)
  390. with pytest.raises(ValueError, match=message):
  391. _chandrupatla_minimize(lambda x: x, *bracket, xatol=-1)
  392. with pytest.raises(ValueError, match=message):
  393. _chandrupatla_minimize(lambda x: x, *bracket, xrtol=xp.nan)
  394. with pytest.raises(ValueError, match=message):
  395. _chandrupatla_minimize(lambda x: x, *bracket, fatol='ekki')
  396. with pytest.raises(ValueError, match=message):
  397. _chandrupatla_minimize(lambda x: x, *bracket, frtol=xp.nan)
  398. message = '`maxiter` must be a non-negative integer.'
  399. with pytest.raises(ValueError, match=message):
  400. _chandrupatla_minimize(lambda x: x, *bracket, maxiter=1.5)
  401. with pytest.raises(ValueError, match=message):
  402. _chandrupatla_minimize(lambda x: x, *bracket, maxiter=-1)
  403. message = '`callback` must be callable.'
  404. with pytest.raises(ValueError, match=message):
  405. _chandrupatla_minimize(lambda x: x, *bracket, callback='shrubbery')
  406. def test_bracket_order(self, xp):
  407. # Confirm that order of points in bracket doesn't
  408. loc = xp.linspace(-1, 1, 6)[:, xp.newaxis]
  409. brackets = xp.asarray(list(permutations([-5, 0, 5]))).T
  410. res = _chandrupatla_minimize(self.f, *brackets, args=(loc,))
  411. assert xp.all(xpx.isclose(res.x, loc) | (res.fun == self.f(loc, loc)))
  412. ref = res.x[:, 0] # all columns should be the same
  413. xp_assert_close(*xp.broadcast_arrays(res.x.T, ref), rtol=1e-15)
  414. def test_special_cases(self, xp):
  415. # Test edge cases and other special cases
  416. # Test that integers are not passed to `f`
  417. def f(x):
  418. assert xp.isdtype(x.dtype, "real floating")
  419. return (x - 1)**2
  420. bracket = xp.asarray(-7), xp.asarray(0), xp.asarray(8)
  421. with np.errstate(invalid='ignore'):
  422. res = _chandrupatla_minimize(f, *bracket, fatol=0, frtol=0)
  423. assert res.success
  424. xp_assert_close(res.x, xp.asarray(1.), rtol=1e-3)
  425. xp_assert_close(res.fun, xp.asarray(0.), atol=1e-200)
  426. # Test that if all elements of bracket equal minimizer, algorithm
  427. # reports convergence
  428. def f(x):
  429. return (x-1)**2
  430. bracket = xp.asarray(1), xp.asarray(1), xp.asarray(1)
  431. res = _chandrupatla_minimize(f, *bracket)
  432. assert res.success
  433. xp_assert_equal(res.x, xp.asarray(1.))
  434. # Test maxiter = 0. Should do nothing to bracket.
  435. def f(x):
  436. return (x-1)**2
  437. bracket = xp.asarray(-3), xp.asarray(1.1), xp.asarray(5)
  438. res = _chandrupatla_minimize(f, *bracket, maxiter=0)
  439. assert res.xl, res.xr == bracket
  440. assert res.nit == 0
  441. assert res.nfev == 3
  442. assert res.status == -2
  443. assert res.x == 1.1 # best so far
  444. # Test scalar `args` (not in tuple)
  445. def f(x, c):
  446. return (x-c)**2 - 1
  447. bracket = xp.asarray(-1), xp.asarray(0), xp.asarray(1)
  448. c = xp.asarray(1/3)
  449. res = _chandrupatla_minimize(f, *bracket, args=(c,))
  450. xp_assert_close(res.x, c)
  451. # Test zero tolerances
  452. def f(x):
  453. return -xp.sin(x)
  454. bracket = xp.asarray(0), xp.asarray(1), xp.asarray(xp.pi)
  455. res = _chandrupatla_minimize(f, *bracket, xatol=0, xrtol=0, fatol=0, frtol=0)
  456. assert res.success
  457. # found a minimum exactly (according to floating point arithmetic)
  458. assert res.xl < res.xm < res.xr
  459. assert f(res.xl) == f(res.xm) == f(res.xr)
  460. @make_xp_test_case(find_root)
  461. class TestFindRoot:
  462. def f(self, q, p):
  463. return special.ndtr(q) - p
  464. @pytest.mark.parametrize('p', [0.6, np.linspace(-0.05, 1.05, 10)])
  465. def test_basic(self, p, xp):
  466. # Invert distribution CDF and compare against distribution `ppf`
  467. a, b = xp.asarray(-5.), xp.asarray(5.)
  468. res = find_root(self.f, (a, b), args=(xp.asarray(p),))
  469. ref = xp.asarray(stats.norm().ppf(p), dtype=xp.asarray(p).dtype)
  470. xp_assert_close(res.x, ref)
  471. @pytest.mark.parametrize('shape', [tuple(), (12,), (3, 4), (3, 2, 2)])
  472. def test_vectorization(self, shape, xp):
  473. # Test for correct functionality, output shapes, and dtypes for various
  474. # input shapes.
  475. p = (np.linspace(-0.05, 1.05, 12).reshape(shape) if shape
  476. else np.float64(0.6))
  477. p_xp = xp.asarray(p)
  478. args_xp = (p_xp,)
  479. dtype = p_xp.dtype
  480. @np.vectorize
  481. def find_root_single(p):
  482. return find_root(self.f, (-5, 5), args=(p,))
  483. def f(*args, **kwargs):
  484. f.f_evals += 1
  485. return self.f(*args, **kwargs)
  486. f.f_evals = 0
  487. bracket = xp.asarray(-5., dtype=xp.float64), xp.asarray(5., dtype=xp.float64)
  488. res = find_root(f, bracket, args=args_xp)
  489. refs = find_root_single(p).ravel()
  490. ref_x = [ref.x for ref in refs]
  491. ref_x = xp.reshape(xp.asarray(ref_x, dtype=dtype), shape)
  492. xp_assert_close(res.x, ref_x)
  493. ref_f = [ref.f_x for ref in refs]
  494. ref_f = xp.reshape(xp.asarray(ref_f, dtype=dtype), shape)
  495. xp_assert_close(res.f_x, ref_f, atol=1e-15)
  496. xp_assert_equal(res.f_x, self.f(res.x, *args_xp))
  497. ref_success = [bool(ref.success) for ref in refs]
  498. ref_success = xp.reshape(xp.asarray(ref_success, dtype=xp.bool), shape)
  499. xp_assert_equal(res.success, ref_success)
  500. ref_status = [ref.status for ref in refs]
  501. ref_status = xp.reshape(xp.asarray(ref_status, dtype=xp.int32), shape)
  502. xp_assert_equal(res.status, ref_status)
  503. ref_nfev = [ref.nfev for ref in refs]
  504. ref_nfev = xp.reshape(xp.asarray(ref_nfev, dtype=xp.int32), shape)
  505. if is_numpy(xp):
  506. xp_assert_equal(res.nfev, ref_nfev)
  507. assert xp.max(res.nfev) == f.f_evals
  508. else: # different backend may lead to different nfev
  509. assert res.nfev.shape == shape
  510. assert res.nfev.dtype == xp.int32
  511. ref_nit = [ref.nit for ref in refs]
  512. ref_nit = xp.reshape(xp.asarray(ref_nit, dtype=xp.int32), shape)
  513. if is_numpy(xp):
  514. xp_assert_equal(res.nit, ref_nit)
  515. assert xp.max(res.nit) == f.f_evals-2
  516. else:
  517. assert res.nit.shape == shape
  518. assert res.nit.dtype == xp.int32
  519. ref_xl = [ref.bracket[0] for ref in refs]
  520. ref_xl = xp.reshape(xp.asarray(ref_xl, dtype=dtype), shape)
  521. xp_assert_close(res.bracket[0], ref_xl)
  522. ref_xr = [ref.bracket[1] for ref in refs]
  523. ref_xr = xp.reshape(xp.asarray(ref_xr, dtype=dtype), shape)
  524. xp_assert_close(res.bracket[1], ref_xr)
  525. xp_assert_less(res.bracket[0], res.bracket[1])
  526. finite = xp.isfinite(res.x)
  527. assert xp.all((res.x[finite] == res.bracket[0][finite])
  528. | (res.x[finite] == res.bracket[1][finite]))
  529. # PyTorch and CuPy don't solve to the same accuracy as NumPy - that's OK.
  530. atol = 1e-15 if is_numpy(xp) else 1e-9
  531. ref_fl = [ref.f_bracket[0] for ref in refs]
  532. ref_fl = xp.reshape(xp.asarray(ref_fl, dtype=dtype), shape)
  533. xp_assert_close(res.f_bracket[0], ref_fl, atol=atol)
  534. xp_assert_equal(res.f_bracket[0], self.f(res.bracket[0], *args_xp))
  535. ref_fr = [ref.f_bracket[1] for ref in refs]
  536. ref_fr = xp.reshape(xp.asarray(ref_fr, dtype=dtype), shape)
  537. xp_assert_close(res.f_bracket[1], ref_fr, atol=atol)
  538. xp_assert_equal(res.f_bracket[1], self.f(res.bracket[1], *args_xp))
  539. assert xp.all(xp.abs(res.f_x[finite]) ==
  540. xp.minimum(xp.abs(res.f_bracket[0][finite]),
  541. xp.abs(res.f_bracket[1][finite])))
  542. def test_flags(self, xp):
  543. # Test cases that should produce different status flags; show that all
  544. # can be produced simultaneously.
  545. def f(xs, js):
  546. # Note that full_like and int(j) shouldn't really be required. CuPy
  547. # is just really picky here, so I'm making it a special case to
  548. # make sure the other backends work when the user is less careful.
  549. assert js.dtype == xp.int64
  550. if is_cupy(xp):
  551. funcs = [lambda x: x - 2.5,
  552. lambda x: x - 10,
  553. lambda x: (x - 0.1)**3,
  554. lambda x: xp.full_like(x, xp.asarray(xp.nan))]
  555. return [funcs[int(j)](x) for x, j in zip(xs, js)]
  556. funcs = [lambda x: x - 2.5,
  557. lambda x: x - 10,
  558. lambda x: (x - 0.1) ** 3,
  559. lambda x: xp.nan]
  560. return [funcs[j](x) for x, j in zip(xs, js)]
  561. args = (xp.arange(4, dtype=xp.int64),)
  562. a, b = xp.asarray([0.]*4), xp.asarray([xp.pi]*4)
  563. res = find_root(f, (a, b), args=args, maxiter=2)
  564. ref_flags = xp.asarray([eim._ECONVERGED,
  565. eim._ESIGNERR,
  566. eim._ECONVERR,
  567. eim._EVALUEERR], dtype=xp.int32)
  568. xp_assert_equal(res.status, ref_flags)
  569. def test_convergence(self, xp):
  570. # Test that the convergence tolerances behave as expected
  571. rng = np.random.default_rng(2585255913088665241)
  572. p = xp.asarray(rng.random(size=3))
  573. bracket = (-xp.asarray(5.), xp.asarray(5.))
  574. args = (p,)
  575. kwargs0 = dict(args=args, tolerances=dict(xatol=0, xrtol=0, fatol=0, frtol=0))
  576. kwargs = deepcopy(kwargs0)
  577. kwargs['tolerances']['xatol'] = 1e-3
  578. res1 = find_root(self.f, bracket, **kwargs)
  579. xp_assert_less(res1.bracket[1] - res1.bracket[0],
  580. xp.full_like(p, xp.asarray(1e-3)))
  581. kwargs['tolerances']['xatol'] = 1e-6
  582. res2 = find_root(self.f, bracket, **kwargs)
  583. xp_assert_less(res2.bracket[1] - res2.bracket[0],
  584. xp.full_like(p, xp.asarray(1e-6)))
  585. xp_assert_less(res2.bracket[1] - res2.bracket[0],
  586. res1.bracket[1] - res1.bracket[0])
  587. kwargs = deepcopy(kwargs0)
  588. kwargs['tolerances']['xrtol'] = 1e-3
  589. res1 = find_root(self.f, bracket, **kwargs)
  590. xp_assert_less(res1.bracket[1] - res1.bracket[0], 1e-3 * xp.abs(res1.x))
  591. kwargs['tolerances']['xrtol'] = 1e-6
  592. res2 = find_root(self.f, bracket, **kwargs)
  593. xp_assert_less(res2.bracket[1] - res2.bracket[0],
  594. 1e-6 * xp.abs(res2.x))
  595. xp_assert_less(res2.bracket[1] - res2.bracket[0],
  596. res1.bracket[1] - res1.bracket[0])
  597. kwargs = deepcopy(kwargs0)
  598. kwargs['tolerances']['fatol'] = 1e-3
  599. res1 = find_root(self.f, bracket, **kwargs)
  600. xp_assert_less(xp.abs(res1.f_x), xp.full_like(p, xp.asarray(1e-3)))
  601. kwargs['tolerances']['fatol'] = 1e-6
  602. res2 = find_root(self.f, bracket, **kwargs)
  603. xp_assert_less(xp.abs(res2.f_x), xp.full_like(p, xp.asarray(1e-6)))
  604. xp_assert_less(xp.abs(res2.f_x), xp.abs(res1.f_x))
  605. kwargs = deepcopy(kwargs0)
  606. kwargs['tolerances']['frtol'] = 1e-3
  607. x1, x2 = bracket
  608. f0 = xp.minimum(xp.abs(self.f(x1, *args)), xp.abs(self.f(x2, *args)))
  609. res1 = find_root(self.f, bracket, **kwargs)
  610. xp_assert_less(xp.abs(res1.f_x), 1e-3*f0)
  611. kwargs['tolerances']['frtol'] = 1e-6
  612. res2 = find_root(self.f, bracket, **kwargs)
  613. xp_assert_less(xp.abs(res2.f_x), 1e-6*f0)
  614. xp_assert_less(xp.abs(res2.f_x), xp.abs(res1.f_x))
  615. def test_maxiter_callback(self, xp):
  616. # Test behavior of `maxiter` parameter and `callback` interface
  617. p = xp.asarray(0.612814)
  618. bracket = (xp.asarray(-5.), xp.asarray(5.))
  619. maxiter = 5
  620. def f(q, p):
  621. res = special.ndtr(q) - p
  622. f.x = q
  623. f.f_x = res
  624. return res
  625. f.x = None
  626. f.f_x = None
  627. res = find_root(f, bracket, args=(p,), maxiter=maxiter)
  628. assert not xp.any(res.success)
  629. assert xp.all(res.nfev == maxiter+2)
  630. assert xp.all(res.nit == maxiter)
  631. def callback(res):
  632. callback.iter += 1
  633. callback.res = res
  634. assert hasattr(res, 'x')
  635. if callback.iter == 0:
  636. # callback is called once with initial bracket
  637. assert (res.bracket[0], res.bracket[1]) == bracket
  638. else:
  639. changed = (((res.bracket[0] == callback.bracket[0])
  640. & (res.bracket[1] != callback.bracket[1]))
  641. | ((res.bracket[0] != callback.bracket[0])
  642. & (res.bracket[1] == callback.bracket[1])))
  643. assert xp.all(changed)
  644. callback.bracket[0] = res.bracket[0]
  645. callback.bracket[1] = res.bracket[1]
  646. assert res.status == eim._EINPROGRESS
  647. xp_assert_equal(self.f(res.bracket[0], p), res.f_bracket[0])
  648. xp_assert_equal(self.f(res.bracket[1], p), res.f_bracket[1])
  649. xp_assert_equal(self.f(res.x, p), res.f_x)
  650. if callback.iter == maxiter:
  651. raise StopIteration
  652. callback.iter = -1 # callback called once before first iteration
  653. callback.res = None
  654. callback.bracket = [None, None]
  655. res2 = find_root(f, bracket, args=(p,), callback=callback)
  656. # terminating with callback is identical to terminating due to maxiter
  657. # (except for `status`)
  658. for key in res.keys():
  659. if key == 'status':
  660. xp_assert_equal(res[key], xp.asarray(eim._ECONVERR, dtype=xp.int32))
  661. xp_assert_equal(res2[key], xp.asarray(eim._ECALLBACK, dtype=xp.int32))
  662. elif key in {'bracket', 'f_bracket'}:
  663. xp_assert_equal(res2[key][0], res[key][0])
  664. xp_assert_equal(res2[key][1], res[key][1])
  665. elif key.startswith('_'):
  666. continue
  667. else:
  668. xp_assert_equal(res2[key], res[key])
  @pytest.mark.parametrize('case', _CHANDRUPATLA_TESTS)
  def test_nit_expected(self, case, xp):
      """Compare `find_root` against the counts published by Chandrupatla.

      Each parametrized case unpacks into the test function, its bracket, a
      reference root, the expected number of function evaluations, and one
      trailing field that this test does not use. The test passes when the
      function value at the reported root is close to the function value at
      the reference root and `nfev` equals the expected count.
      """
      func, raw_bracket, ref_root, expected_nfev, _ = case

      f64 = xp.float64
      lower = xp.asarray(raw_bracket[0], dtype=f64)
      upper = xp.asarray(raw_bracket[1], dtype=f64)
      ref_root = xp.asarray(ref_root, dtype=f64)

      # Chandrupatla's termination criterion is equivalent to
      #     abs(x2-x1) < 4*abs(xmin)*xrtol + xatol,
      # whereas the more standard form used here is
      #     abs(x2-x1) < abs(xmin)*xrtol + xatol.
      # Hence `xrtol` is four times the value Chandrupatla used in his tests.
      tolerances = {'xrtol': 4e-10, 'xatol': 1e-5}
      result = find_root(func, (lower, upper), tolerances=tolerances)

      expected_f_x = xp.asarray(func(ref_root), dtype=f64)
      xp_assert_close(result.f_x, expected_f_x, rtol=1e-8, atol=2e-3)

      expected_count = xp.asarray(expected_nfev, dtype=xp.int32)
      xp_assert_equal(result.nfev, expected_count)
  @pytest.mark.parametrize("root", (0.622, [0.622, 0.623]))
  @pytest.mark.parametrize("dtype", ('float16', 'float32', 'float64'))
  def test_dtype(self, root, dtype, xp):
      """Test that `find_root` preserves the dtype of its inputs.

      Parameters
      ----------
      root : float or list of float
          Root(s) of the cubic test function ``(x - root)**3``.
      dtype : str
          Name of the floating point dtype, looked up as an attribute of `xp`.
      xp : module
          Array namespace under test.
      """
      numpy_backend = is_numpy(xp)
      if not numpy_backend and dtype == 'float16':
          pytest.skip("`float16` dtype only supported for NumPy arrays.")

      # Keep the requested name separately: if the lookup fails, `dtype`
      # becomes None and the skip message would read "does not support None".
      dtype_name = dtype
      dtype = getattr(xp, dtype_name, None)
      if dtype is None:
          pytest.skip(f"{xp} does not support {dtype_name}")

      def f(x, root):
          res = (x - root) ** 3.
          if numpy_backend:  # NumPy does not preserve dtype
              return xp.asarray(res, dtype=dtype)
          return res

      xatol = 1e-3
      a, b = xp.asarray(-3, dtype=dtype), xp.asarray(3, dtype=dtype)
      root = xp.asarray(root, dtype=dtype)
      res = find_root(f, (a, b), args=(root,), tolerances={'xatol': xatol})
      try:
          xp_assert_close(res.x, root, atol=xatol)
      except AssertionError:
          # Fallback when the closeness check fails: the dtype must still be
          # preserved, and any element farther than `xatol` from its root is
          # only acceptable if the function value there is exactly zero.
          # This check was previously computed without `assert`, so it could
          # never fail. It is applied element-wise so that elements which did
          # land near their root are not required to have a zero residual.
          assert res.x.dtype == dtype
          near_root = xp.abs(res.x - root) <= xatol
          assert xp.all(near_root | (res.f_x == 0))
  def test_input_validation(self, xp):
      """Check that invalid `find_root` arguments raise the expected errors."""

      def identity(x):
          return x

      # `func` is not callable
      with pytest.raises(ValueError, match='`func` must be callable.'):
          bracket = xp.asarray(-4), xp.asarray(4)
          find_root(None, bracket)

      # complex abscissa
      real_only = 'Abscissae and function output must be real numbers.'
      with pytest.raises(ValueError, match=real_only):
          bracket = xp.asarray(-4+1j), xp.asarray(4)
          find_root(identity, bracket)

      # Bracket ends that cannot be broadcast together. The pattern accepts
      # either of two phrasings of the broadcasting failure, and either
      # exception type is allowed.
      not_broadcastable = (
          "(not be broadcast|Attempting to broadcast a dimension of length)"
      )
      with pytest.raises((ValueError, RuntimeError), match=not_broadcastable):
          bracket = xp.asarray([-2, -3]), xp.asarray([3, 4, 5])
          find_root(identity, bracket)

      # callable returns three values for a two-element bracket
      bad_shape = "The shape of the array returned by `func`..."
      with pytest.raises(ValueError, match=bad_shape):
          bracket = xp.asarray([-3, -3]), xp.asarray([5, 5])
          find_root(lambda x: [x[0], x[1], x[1]], bracket)

      # The remaining checks share one valid scalar bracket.
      bracket = xp.asarray(-4), xp.asarray(4)

      bad_tolerance = 'Tolerances must be non-negative scalars.'
      invalid_tolerances = (
          {'xatol': -1},
          {'xrtol': xp.nan},
          {'fatol': 'ekki'},
          {'frtol': xp.nan},
      )
      for tolerances in invalid_tolerances:
          with pytest.raises(ValueError, match=bad_tolerance):
              find_root(identity, bracket, tolerances=tolerances)

      bad_maxiter = '`maxiter` must be a non-negative integer.'
      for maxiter in (1.5, -1):
          with pytest.raises(ValueError, match=bad_maxiter):
              find_root(identity, bracket, maxiter=maxiter)

      with pytest.raises(ValueError, match='`callback` must be callable.'):
          find_root(identity, bracket, callback='shrubbery')
  def test_special_cases(self, xp):
      """Exercise `find_root` on edge cases and other special inputs."""

      # --- Infinite function values at one or both bracket ends ---
      def with_poles(x):
          return 1 / x + 1 - 1 / (-x + 1)

      lower = xp.asarray([0.1, 0., 0., 0.1])
      upper = xp.asarray([0.9, 1.0, 0.9, 1.0])
      with np.errstate(divide='ignore', invalid='ignore'):
          result = find_root(with_poles, (lower, upper))

      assert xp.all(result.success)
      # all four brackets must yield the same root as the first one
      xp_assert_close(result.x[1:], xp.full((3,), result.x[0]))

      # --- Integers must not be passed to the callable ---
      def high_power(x):
          assert xp.isdtype(x.dtype, "real floating")
          # this would overflow if x were an xp integer dtype
          return x ** 31 - 1

      # every input is integer-typed; the result is the default float type
      result = find_root(high_power, (xp.asarray(-7), xp.asarray(5)))
      assert result.success
      xp_assert_close(result.x, xp.asarray(1.))

      # --- Both bracket ends equal: convergence is reported only where that
      # common point is a root ---
      def shifted_square(x, root):
          return x**2 - root

      targets = xp.asarray([0, 1])
      degenerate = (xp.asarray(1), xp.asarray(1))
      result = find_root(shifted_square, degenerate, args=(targets,))
      xp_assert_equal(result.success, xp.asarray([False, True]))
      xp_assert_equal(result.x, xp.asarray([xp.nan, 1.]))

      # --- Both bracket ends infinite ---
      def reciprocal(x):
          return 1/x

      with np.errstate(invalid='ignore'):
          inf = xp.asarray(xp.inf)
          result = find_root(reciprocal, (inf, inf))

      assert result.success
      xp_assert_equal(result.x, xp.asarray(xp.inf))

      # --- maxiter = 0 should do nothing to the bracket ---
      def cubic(x):
          return x**3 - 1

      left, right = xp.asarray(-3.), xp.asarray(5.)
      result = find_root(cubic, (left, right), maxiter=0)
      xp_assert_equal(result.success, xp.asarray(False))
      xp_assert_equal(result.status, xp.asarray(-2, dtype=xp.int32))
      xp_assert_equal(result.nit, xp.asarray(0, dtype=xp.int32))
      xp_assert_equal(result.nfev, xp.asarray(2, dtype=xp.int32))
      xp_assert_equal(result.bracket[0], left)
      xp_assert_equal(result.bracket[1], right)
      # The `x` attribute is the one with the smaller function value
      xp_assert_equal(result.x, left)

      # Reverse bracket; check that this is still true
      result = find_root(cubic, (-right, -left), maxiter=0)
      xp_assert_equal(result.x, -left)

      # --- maxiter = 1 ---
      result = find_root(cubic, (left, right), maxiter=1)
      xp_assert_equal(result.success, xp.asarray(True))
      xp_assert_equal(result.status, xp.asarray(0, dtype=xp.int32))
      xp_assert_equal(result.nit, xp.asarray(1, dtype=xp.int32))
      xp_assert_equal(result.nfev, xp.asarray(3, dtype=xp.int32))
      xp_assert_close(result.x, xp.asarray(1.))

      # --- Scalar `args` (not wrapped in a tuple) ---
      def linear(x, c):
          return c*x - 1

      unit_bracket = (xp.asarray(-1), xp.asarray(1))
      result = find_root(linear, unit_bracket, args=xp.asarray(3))
      xp_assert_close(result.x, xp.asarray(1/3))

      # # TODO: Test zero tolerance
      # # ~~What's going on here - why are iterations repeated?~~
      # # tl goes to zero when xatol=xrtol=0. When function is nearly linear,
      # # this causes convergence issues.
      # def f(x):
      #     return np.cos(x)
      #
      # res = _chandrupatla_root(f, 0, np.pi, xatol=0, xrtol=0)
      # assert res.nit < 100
      # xp = np.nextafter(res.x, np.inf)
      # xm = np.nextafter(res.x, -np.inf)
      # assert np.abs(res.fun) < np.abs(f(xp))
      # assert np.abs(res.fun) < np.abs(f(xm))