test_tanhsinh.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158
  1. # mypy: disable-error-code="attr-defined"
  2. import os
  3. import pytest
  4. import math
  5. import numpy as np
  6. from numpy.testing import assert_allclose
  7. import scipy._lib._elementwise_iterative_method as eim
  8. from scipy._lib._array_api_no_0d import xp_assert_close, xp_assert_equal
  9. from scipy._lib._array_api import (array_namespace, xp_size, xp_ravel, xp_copy,
  10. is_numpy, make_xp_test_case)
  11. from scipy import special, stats
  12. from scipy.integrate import quad_vec, nsum, tanhsinh as _tanhsinh
  13. from scipy.integrate._tanhsinh import _pair_cache
  14. from scipy.special._ufuncs import _gen_harmonic
  15. def norm_pdf(x, xp=None):
  16. xp = array_namespace(x) if xp is None else xp
  17. return 1/(2*xp.pi)**0.5 * xp.exp(-x**2/2)
  18. def norm_logpdf(x, xp=None):
  19. xp = array_namespace(x) if xp is None else xp
  20. return -0.5*math.log(2*xp.pi) - x**2/2
  21. def _vectorize(xp):
  22. # xp-compatible version of np.vectorize
  23. # assumes arguments are all arrays of the same shape
  24. def decorator(f):
  25. def wrapped(*arg_arrays):
  26. shape = arg_arrays[0].shape
  27. arg_arrays = [xp_ravel(arg_array) for arg_array in arg_arrays]
  28. res = []
  29. for i in range(math.prod(shape)):
  30. arg_scalars = [arg_array[i] for arg_array in arg_arrays]
  31. res.append(f(*arg_scalars))
  32. return res
  33. return wrapped
  34. return decorator
@make_xp_test_case(_tanhsinh)
class TestTanhSinh:
    """Tests for `tanhsinh`, imported in this module as `_tanhsinh`.

    Methods ``f1`` through ``f15`` are test integrands. Each carries the
    attributes ``b`` and ``ref``. ``b`` is the upper limit of integration;
    `test_basic` uses 0 as the lower limit. ``ref`` is the reference value of
    the integral over ``[0, b]``.
    """

    # Test problems from [1] Section 6
    def f1(self, t):
        return t * np.log(1 + t)

    f1.ref = 0.25
    f1.b = 1

    def f2(self, t):
        return t ** 2 * np.arctan(t)

    f2.ref = (np.pi - 2 + 2 * np.log(2)) / 12
    f2.b = 1

    def f3(self, t):
        return np.exp(t) * np.cos(t)

    f3.ref = (np.exp(np.pi / 2) - 1) / 2
    f3.b = np.pi / 2

    def f4(self, t):
        a = np.sqrt(2 + t ** 2)
        return np.arctan(a) / ((1 + t ** 2) * a)

    f4.ref = 5 * np.pi ** 2 / 96
    f4.b = 1

    def f5(self, t):
        return np.sqrt(t) * np.log(t)

    f5.ref = -4 / 9
    f5.b = 1

    def f6(self, t):
        return np.sqrt(1 - t ** 2)

    f6.ref = np.pi / 4
    f6.b = 1

    def f7(self, t):
        return np.sqrt(t) / np.sqrt(1 - t ** 2)

    f7.ref = 2 * np.sqrt(np.pi) * special.gamma(3 / 4) / special.gamma(1 / 4)
    f7.b = 1

    def f8(self, t):
        return np.log(t) ** 2

    f8.ref = 2
    f8.b = 1

    def f9(self, t):
        return np.log(np.cos(t))

    f9.ref = -np.pi * np.log(2) / 2
    f9.b = np.pi / 2

    def f10(self, t):
        return np.sqrt(np.tan(t))

    f10.ref = np.pi * np.sqrt(2) / 2
    f10.b = np.pi / 2

    def f11(self, t):
        return 1 / (1 + t ** 2)

    f11.ref = np.pi / 2
    f11.b = np.inf

    def f12(self, t):
        return np.exp(-t) / np.sqrt(t)

    f12.ref = np.sqrt(np.pi)
    f12.b = np.inf

    def f13(self, t):
        return np.exp(-t ** 2 / 2)

    f13.ref = np.sqrt(np.pi / 2)
    f13.b = np.inf

    def f14(self, t):
        return np.exp(-t) * np.cos(t)

    f14.ref = 0.5
    f14.b = np.inf

    def f15(self, t):
        return np.sin(t) / t

    f15.ref = np.pi / 2
    f15.b = np.inf

    def error(self, res, ref, log=False, xp=None):
        """Return ``abs(res - ref)``, or its base-10 logarithm if `log` is True.

        `xp` defaults to the array namespace of `res` and `ref`. NumPy's
        divide-by-zero warning, raised by ``log10(0)`` when the error is
        exactly zero, is suppressed.
        """
        xp = array_namespace(res, ref) if xp is None else xp
        err = abs(res - ref)

        if not log:
            return err

        with np.errstate(divide='ignore'):
            return xp.log10(err)

    def test_input_validation(self, xp):
        """Invalid arguments to `tanhsinh` raise ``ValueError``."""
        f = self.f1
        zero = xp.asarray(0)
        f_b = xp.asarray(f.b)

        # The leading '...' in these messages are regex wildcards for `match`.
        message = '`f` must be callable.'
        with pytest.raises(ValueError, match=message):
            _tanhsinh(42, zero, f_b)

        message = '...must be True or False.'
        with pytest.raises(ValueError, match=message):
            _tanhsinh(f, zero, f_b, log=2)

        message = '...must be real numbers.'
        with pytest.raises(ValueError, match=message):
            _tanhsinh(f, xp.asarray(1+1j), f_b)
        with pytest.raises(ValueError, match=message):
            _tanhsinh(f, zero, f_b, atol='ekki')
        with pytest.raises(ValueError, match=message):
            _tanhsinh(f, zero, f_b, rtol=pytest)

        message = '...must be non-negative and finite.'
        with pytest.raises(ValueError, match=message):
            _tanhsinh(f, zero, f_b, rtol=-1)
        with pytest.raises(ValueError, match=message):
            _tanhsinh(f, zero, f_b, atol=xp.inf)

        message = '...may not be positive infinity.'
        with pytest.raises(ValueError, match=message):
            _tanhsinh(f, zero, f_b, rtol=xp.inf, log=True)
        with pytest.raises(ValueError, match=message):
            _tanhsinh(f, zero, f_b, atol=xp.inf, log=True)

        message = '...must be integers.'
        with pytest.raises(ValueError, match=message):
            _tanhsinh(f, zero, f_b, maxlevel=object())
        # with pytest.raises(ValueError, match=message):  # unused for now
        #     _tanhsinh(f, zero, f_b, maxfun=1+1j)
        with pytest.raises(ValueError, match=message):
            _tanhsinh(f, zero, f_b, minlevel="migratory coconut")

        message = '...must be non-negative.'
        with pytest.raises(ValueError, match=message):
            _tanhsinh(f, zero, f_b, maxlevel=-1)
        # with pytest.raises(ValueError, match=message):  # unused for now
        #     _tanhsinh(f, zero, f_b, maxfun=-1)
        with pytest.raises(ValueError, match=message):
            _tanhsinh(f, zero, f_b, minlevel=-1)

        message = '...must be True or False.'
        with pytest.raises(ValueError, match=message):
            _tanhsinh(f, zero, f_b, preserve_shape=2)

        message = '...must be callable.'
        with pytest.raises(ValueError, match=message):
            _tanhsinh(f, zero, f_b, callback='elderberry')

    @pytest.mark.parametrize("limits, ref", [
        [(0, math.inf), 0.5],  # b infinite
        [(-math.inf, 0), 0.5],  # a infinite
        [(-math.inf, math.inf), 1.],  # a and b infinite
        [(math.inf, -math.inf), -1.],  # flipped limits
        [(1, -1), stats.norm.cdf(-1.) - stats.norm.cdf(1.)],  # flipped limits
    ])
    def test_integral_transforms(self, limits, ref, xp):
        # Check that the integral transforms are behaving for both normal and
        # log integration
        limits = [xp.asarray(limit) for limit in limits]
        dtype = xp.asarray(float(limits[0])).dtype
        ref = xp.asarray(ref, dtype=dtype)

        res = _tanhsinh(norm_pdf, *limits)
        xp_assert_close(res.integral, ref)

        logres = _tanhsinh(norm_logpdf, *limits, log=True)
        xp_assert_close(xp.exp(logres.integral), ref, check_dtype=False)
        # Transformation should not make the result complex unnecessarily.
        # A negative reference (flipped limits) has a complex logarithm.
        assert (xp.isdtype(logres.integral.dtype, "real floating") if ref > 0
                else xp.isdtype(logres.integral.dtype, "complex floating"))

        atol = 2 * xp.finfo(res.error.dtype).eps
        xp_assert_close(xp.exp(logres.error), res.error, atol=atol, check_dtype=False)

    # 15 skipped intentionally; it's very difficult numerically
    @pytest.mark.skip_xp_backends(np_only=True,
                                  reason='Cumbersome to convert everything.')
    @pytest.mark.parametrize('f_number', range(1, 15))
    def test_basic(self, f_number, xp):
        """Integrate ``fN`` over ``[0, fN.b]`` and compare against ``fN.ref``."""
        f = getattr(self, f"f{f_number}")
        rtol = 2e-8
        res = _tanhsinh(f, 0, f.b, rtol=rtol)
        assert_allclose(res.integral, f.ref, rtol=rtol)
        if f_number not in {7, 12, 14}:  # mildly underestimates error here
            true_error = abs(self.error(res.integral, f.ref)/res.integral)
            assert true_error < res.error

        if f_number in {7, 10, 12}:  # succeeds, but doesn't know it
            return

        assert res.success
        assert res.status == 0

    @pytest.mark.skip_xp_backends(np_only=True,
                                  reason="Distributions aren't xp-compatible.")
    @pytest.mark.parametrize('ref', (0.5, [0.4, 0.6]))
    @pytest.mark.parametrize('case', stats._distr_params.distcont)
    def test_accuracy(self, ref, case, xp):
        """Integrating a continuous distribution's PDF over its central
        `ref`-probability interval recovers `ref`."""
        distname, params = case
        if distname in {'dgamma', 'dweibull', 'laplace', 'kstwo'}:
            # should split up interval at first-derivative discontinuity
            pytest.skip('tanh-sinh is not great for non-smooth integrands')
        if (distname in {'studentized_range', 'levy_stable'}
                and not int(os.getenv('SCIPY_XSLOW', 0))):
            pytest.skip('This case passes, but it is too slow.')
        dist = getattr(stats, distname)(*params)
        x = dist.interval(ref)
        res = _tanhsinh(dist.pdf, *x)
        assert_allclose(res.integral, ref)

    @pytest.mark.parametrize('shape', [tuple(), (12,), (3, 4), (3, 2, 2)])
    def test_vectorization(self, shape, xp):
        # Test for correct functionality, output shapes, and dtypes for various
        # input shapes.
        rng = np.random.default_rng(82456839535679456794)
        a = xp.asarray(rng.random(shape))
        b = xp.asarray(rng.random(shape))
        p = xp.asarray(rng.random(shape))
        n = math.prod(shape)

        # `f` counts its calls (`ncall`) and the evaluations per element
        # (`feval`); these are later compared with `maxlevel` and `nfev`.
        def f(x, p):
            f.ncall += 1
            f.feval += 1 if (xp_size(x) == n or x.ndim <= 1) else x.shape[-1]
            return x**p
        f.ncall = 0
        f.feval = 0

        # Reference: integrate each element separately.
        @_vectorize(xp)
        def _tanhsinh_single(a, b, p):
            return _tanhsinh(lambda x: x**p, a, b)

        res = _tanhsinh(f, a, b, args=(p,))
        refs = _tanhsinh_single(a, b, p)

        attrs = ['integral', 'error', 'success', 'status', 'nfev', 'maxlevel']
        for attr in attrs:
            ref_attr = xp.stack([getattr(ref, attr) for ref in refs])
            res_attr = xp_ravel(getattr(res, attr))
            xp_assert_close(res_attr, ref_attr, rtol=1e-15)
            assert getattr(res, attr).shape == shape

        assert xp.isdtype(res.success.dtype, 'bool')
        assert xp.isdtype(res.status.dtype, 'integral')
        assert xp.isdtype(res.nfev.dtype, 'integral')
        assert xp.isdtype(res.maxlevel.dtype, 'integral')
        assert xp.max(res.nfev) == f.feval
        # maxlevel = 2 -> 3 function calls (2 initialization, 1 work)
        assert xp.max(res.maxlevel) >= 2
        assert xp.max(res.maxlevel) == f.ncall

    def test_flags(self, xp):
        # Test cases that should produce different status flags; show that all
        # can be produced simultaneously.
        # `js` selects, per element, which integrand from `funcs` to use.
        def f(xs, js):
            f.nit += 1
            funcs = [lambda x: xp.exp(-x**2),  # converges (status 0)
                     lambda x: xp.exp(x),  # does not converge by maxlevel (status -2)
                     lambda x: xp.full_like(x, xp.nan)]  # stops due to NaN (status -3)
            res = []
            for i in range(xp_size(js)):
                x = xs[i, ...]
                j = int(xp_ravel(js)[i])
                res.append(funcs[j](x))
            return xp.stack(res)
        f.nit = 0

        args = (xp.arange(3, dtype=xp.int64),)
        a = xp.asarray([xp.inf]*3)
        b = xp.asarray([-xp.inf] * 3)
        res = _tanhsinh(f, a, b, maxlevel=5, args=args)
        ref_flags = xp.asarray([0, -2, -3], dtype=xp.int32)
        xp_assert_equal(res.status, ref_flags)

    def test_flags_preserve_shape(self, xp):
        # Same test as above but using `preserve_shape` option to simplify.
        def f(x):
            res = [xp.exp(-x[0]**2),  # converges (status 0)
                   xp.exp(x[1]),  # does not converge by maxlevel (status -2)
                   xp.full_like(x[2], xp.nan)]  # stops due to NaN (status -3)
            return xp.stack(res)

        a = xp.asarray([xp.inf] * 3)
        b = xp.asarray([-xp.inf] * 3)
        res = _tanhsinh(f, a, b, maxlevel=5, preserve_shape=True)
        ref_flags = xp.asarray([0, -2, -3], dtype=xp.int32)
        xp_assert_equal(res.status, ref_flags)

    def test_preserve_shape(self, xp):
        # Test `preserve_shape` option: a (2, 2)-valued integrand is compared
        # against `quad_vec` with NumPy.
        def f(x, xp):
            return xp.stack([xp.stack([x, xp.sin(10 * x)]),
                             xp.stack([xp.cos(30 * x), x * xp.sin(100 * x)])])

        ref = quad_vec(lambda x: f(x, np), 0, 1)
        res = _tanhsinh(lambda x: f(x, xp), xp.asarray(0), xp.asarray(1),
                        preserve_shape=True)
        dtype = xp.asarray(0.).dtype
        xp_assert_close(res.integral, xp.asarray(ref[0], dtype=dtype))

    def test_convergence(self, xp):
        # demonstrate that number of accurate digits doubles each iteration
        dtype = xp.float64  # this only works with good precision
        def f(t):
            return t * xp.log(1 + t)
        ref = xp.asarray(0.25, dtype=dtype)
        a, b = xp.asarray(0., dtype=dtype), xp.asarray(1., dtype=dtype)

        last_logerr = 0
        for i in range(4):
            res = _tanhsinh(f, a, b, minlevel=0, maxlevel=i)
            logerr = self.error(res.integral, ref, log=True, xp=xp)
            # log10 errors are negative, so "doubling digits" means at most
            # twice the previous value; -15.5 is roughly float64 resolution.
            assert (logerr < last_logerr * 2 or logerr < -15.5)
            last_logerr = logerr

    def test_options_and_result_attributes(self, xp):
        # demonstrate that options are behaving as advertised and status
        # messages are as intended
        # `f.calls` counts calls and `f.feval` counts evaluations; both are
        # reset before every `_tanhsinh` call below.
        def f(x):
            f.calls += 1
            f.feval += xp_size(xp.asarray(x))
            return x**2 * xp.atan(x)

        f.ref = xp.asarray((math.pi - 2 + 2 * math.log(2)) / 12, dtype=xp.float64)

        default_rtol = 1e-12
        default_atol = f.ref * default_rtol  # effective default absolute tol

        # Keep things simpler by leaving tolerances fixed rather than
        # having to make them dtype-dependent
        a = xp.asarray(0., dtype=xp.float64)
        b = xp.asarray(1., dtype=xp.float64)

        # Test default options
        f.feval, f.calls = 0, 0
        ref = _tanhsinh(f, a, b)
        assert self.error(ref.integral, f.ref) < ref.error < default_atol
        assert ref.nfev == f.feval
        ref.calls = f.calls  # reference number of function calls
        assert ref.success
        assert ref.status == 0

        # Test `maxlevel` equal to required max level
        # We should get all the same results
        f.feval, f.calls = 0, 0
        maxlevel = int(ref.maxlevel)
        res = _tanhsinh(f, a, b, maxlevel=maxlevel)
        res.calls = f.calls
        assert res == ref

        # Now reduce the maximum level. We won't meet tolerances.
        f.feval, f.calls = 0, 0
        maxlevel -= 1
        assert maxlevel >= 2  # can't compare errors otherwise
        res = _tanhsinh(f, a, b, maxlevel=maxlevel)
        assert self.error(res.integral, f.ref) < res.error > default_atol
        assert res.nfev == f.feval < ref.nfev
        assert f.calls == ref.calls - 1
        assert not res.success
        assert res.status == eim._ECONVERR

        # `maxfun` is currently not enforced

        # # Test `maxfun` equal to required number of function evaluations
        # # We should get all the same results
        # f.feval, f.calls = 0, 0
        # maxfun = ref.nfev
        # res = _tanhsinh(f, 0, f.b, maxfun = maxfun)
        # assert res == ref
        #
        # # Now reduce `maxfun`. We won't meet tolerances.
        # f.feval, f.calls = 0, 0
        # maxfun -= 1
        # res = _tanhsinh(f, 0, f.b, maxfun=maxfun)
        # assert self.error(res.integral, f.ref) < res.error > default_atol
        # assert res.nfev == f.feval < ref.nfev
        # assert f.calls == ref.calls - 1
        # assert not res.success
        # assert res.status == 2

        # Take this result to be the new reference
        ref = res
        ref.calls = f.calls

        # Test `atol`
        f.feval, f.calls = 0, 0
        # With this tolerance, we should get the exact same result as ref
        atol = np.nextafter(float(ref.error), np.inf)
        res = _tanhsinh(f, a, b, rtol=0, atol=atol)
        assert res.integral == ref.integral
        assert res.error == ref.error
        assert res.nfev == f.feval == ref.nfev
        assert f.calls == ref.calls
        # Except the result is considered to be successful
        assert res.success
        assert res.status == 0

        f.feval, f.calls = 0, 0
        # With a tighter tolerance, we should get a more accurate result
        atol = np.nextafter(float(ref.error), -np.inf)
        res = _tanhsinh(f, a, b, rtol=0, atol=atol)
        assert self.error(res.integral, f.ref) < res.error < atol
        assert res.nfev == f.feval > ref.nfev
        assert f.calls > ref.calls
        assert res.success
        assert res.status == 0

        # Test `rtol`
        f.feval, f.calls = 0, 0
        # With this tolerance, we should get the exact same result as ref
        rtol = np.nextafter(float(ref.error/ref.integral), np.inf)
        res = _tanhsinh(f, a, b, rtol=rtol)
        assert res.integral == ref.integral
        assert res.error == ref.error
        assert res.nfev == f.feval == ref.nfev
        assert f.calls == ref.calls
        # Except the result is considered to be successful
        assert res.success
        assert res.status == 0

        f.feval, f.calls = 0, 0
        # With a tighter tolerance, we should get a more accurate result
        rtol = np.nextafter(float(ref.error/ref.integral), -np.inf)
        res = _tanhsinh(f, a, b, rtol=rtol)
        assert self.error(res.integral, f.ref)/f.ref < res.error/res.integral < rtol
        assert res.nfev == f.feval > ref.nfev
        assert f.calls > ref.calls
        assert res.success
        assert res.status == 0

    @pytest.mark.skip_xp_backends('torch', reason=
            'https://github.com/scipy/scipy/pull/21149#issuecomment-2330477359',
    )
    @pytest.mark.parametrize('rtol', [1e-4, 1e-14])
    def test_log(self, rtol, xp):
        # Test equivalence of log-integration and regular integration
        test_tols = dict(atol=1e-18, rtol=1e-15)

        # Positive integrand (real log-integrand)
        a = xp.asarray(-1., dtype=xp.float64)
        b = xp.asarray(2., dtype=xp.float64)
        res = _tanhsinh(norm_logpdf, a, b, log=True, rtol=math.log(rtol))
        ref = _tanhsinh(norm_pdf, a, b, rtol=rtol)
        xp_assert_close(xp.exp(res.integral), ref.integral, **test_tols)
        xp_assert_close(xp.exp(res.error), ref.error, **test_tols)
        assert res.nfev == ref.nfev

        # Real integrand (complex log-integrand)
        def f(x):
            return -norm_logpdf(x)*norm_pdf(x)

        # norm_logpdf(x) < 0 everywhere, so log(norm_logpdf(x) + 0j) has
        # imaginary part pi; adding another pi*1j leaves exp(logf) == f.
        def logf(x):
            return xp.log(norm_logpdf(x) + 0j) + norm_logpdf(x) + xp.pi * 1j

        a = xp.asarray(-xp.inf, dtype=xp.float64)
        b = xp.asarray(xp.inf, dtype=xp.float64)
        res = _tanhsinh(logf, a, b, log=True)
        ref = _tanhsinh(f, a, b)
        # In gh-19173, we saw `invalid` warnings on one CI platform.
        # Silencing `all` because I can't reproduce locally and don't want
        # to risk the need to run CI again.
        with np.errstate(all='ignore'):
            xp_assert_close(xp.exp(res.integral), ref.integral, **test_tols,
                            check_dtype=False)
            xp_assert_close(xp.exp(res.error), ref.error, **test_tols,
                            check_dtype=False)
        assert res.nfev == ref.nfev

    def test_complex(self, xp):
        # Test integration of complex integrand
        # Finite limits
        def f(x):
            return xp.exp(1j * x)

        a, b = xp.asarray(0.), xp.asarray(xp.pi/4)
        res = _tanhsinh(f, a, b)
        ref = math.sqrt(2)/2 + (1-math.sqrt(2)/2)*1j
        xp_assert_close(res.integral, xp.asarray(ref))

        # Infinite limits
        def f(x):
            return norm_pdf(x) + 1j/2*norm_pdf(x/2)

        a, b = xp.asarray(xp.inf), xp.asarray(-xp.inf)
        res = _tanhsinh(f, a, b)
        xp_assert_close(res.integral, xp.asarray(-(1+1j)))

    @pytest.mark.parametrize("maxlevel", range(4))
    def test_minlevel(self, maxlevel, xp):
        # Verify that minlevel does not change the values at which the
        # integrand is evaluated or the integral/error estimates, only the
        # number of function calls
        # `f.x` accumulates every abscissa passed to `f`.
        def f(x):
            f.calls += 1
            f.feval += xp_size(xp.asarray(x))
            f.x = xp.concat((f.x, xp_ravel(x)))
            return x**2 * xp.atan(x)

        f.feval, f.calls, f.x = 0, 0, xp.asarray([])

        a = xp.asarray(0, dtype=xp.float64)
        b = xp.asarray(1, dtype=xp.float64)
        ref = _tanhsinh(f, a, b, minlevel=0, maxlevel=maxlevel)
        ref_x = xp.sort(f.x)

        for minlevel in range(0, maxlevel + 1):
            f.feval, f.calls, f.x = 0, 0, xp.asarray([])
            options = dict(minlevel=minlevel, maxlevel=maxlevel)
            res = _tanhsinh(f, a, b, **options)
            # Should be very close; all that has changed is the order of values
            xp_assert_close(res.integral, ref.integral, rtol=4e-16)
            # Difference in absolute errors << magnitude of integral
            xp_assert_close(res.error, ref.error, atol=4e-16 * ref.integral)
            assert res.nfev == f.feval == f.x.shape[0]
            assert f.calls == maxlevel - minlevel + 1 + 1  # 1 validation call
            assert res.status == ref.status
            xp_assert_equal(ref_x, xp.sort(f.x))

    def test_improper_integrals(self, xp):
        # Test handling of infinite limits of integration (mixed with finite limits)
        # Any infinite abscissa passed to `f` is replaced by NaN, so evaluating
        # the integrand at an infinite point would presumably poison the
        # result and fail the comparison below.
        def f(x):
            x[xp.isinf(x)] = xp.nan
            return xp.exp(-x**2)
        a = xp.asarray([-xp.inf, 0, -xp.inf, xp.inf, -20, -xp.inf, -20])
        b = xp.asarray([xp.inf, xp.inf, 0, -xp.inf, 20, 20, xp.inf])
        ref = math.sqrt(math.pi)
        ref = xp.asarray([ref, ref/2, ref/2, -ref, ref, ref, ref])
        res = _tanhsinh(f, a, b)
        xp_assert_close(res.integral, ref)

    @pytest.mark.parametrize("limits", ((0, 3), ([-math.inf, 0], [3, 3])))
    @pytest.mark.parametrize("dtype", ('float32', 'float64'))
    def test_dtype(self, limits, dtype, xp):
        # Test that dtypes are preserved
        dtype = getattr(xp, dtype)
        a, b = xp.asarray(limits, dtype=dtype)

        def f(x):
            assert x.dtype == dtype
            return xp.exp(x)

        rtol = 1e-12 if dtype == xp.float64 else 1e-5
        res = _tanhsinh(f, a, b, rtol=rtol)
        assert res.integral.dtype == dtype
        assert res.error.dtype == dtype
        assert xp.all(res.success)
        xp_assert_close(res.integral, xp.exp(b)-xp.exp(a))

    def test_maxiter_callback(self, xp):
        # Test the `callback` interface: raising StopIteration after `maxiter`
        # callback iterations must give the same result as stopping at
        # `maxlevel`, except for `status` (-4 instead of -2).
        a, b = xp.asarray(-xp.inf), xp.asarray(xp.inf)
        def f(x):
            return xp.exp(-x*x)

        minlevel, maxlevel = 0, 2
        maxiter = maxlevel - minlevel + 1
        kwargs = dict(minlevel=minlevel, maxlevel=maxlevel, rtol=1e-15)
        res = _tanhsinh(f, a, b, **kwargs)
        assert not res.success
        assert res.maxlevel == maxlevel

        def callback(res):
            callback.iter += 1
            callback.res = res
            assert hasattr(res, 'integral')
            assert res.status == 1
            if callback.iter == maxiter:
                raise StopIteration
        callback.iter = -1  # callback called once before first iteration
        callback.res = None

        del kwargs['maxlevel']
        res2 = _tanhsinh(f, a, b, **kwargs, callback=callback)
        # terminating with callback is identical to terminating due to maxiter
        # (except for `status`)
        for key in res.keys():
            if key == 'status':
                assert res[key] == -2
                assert res2[key] == -4
            else:
                assert res2[key] == callback.res[key] == res[key]

    def test_jumpstart(self, xp):
        # The intermediate results at each level i should be the same as the
        # final results when jumpstarting at level i; i.e. minlevel=maxlevel=i
        a = xp.asarray(-xp.inf, dtype=xp.float64)
        b = xp.asarray(xp.inf, dtype=xp.float64)

        def f(x):
            return xp.exp(-x*x)

        # Record copies of the intermediate estimates at every callback.
        def callback(res):
            callback.integrals.append(xp_copy(res.integral)[()])
            callback.errors.append(xp_copy(res.error)[()])
        callback.integrals = []
        callback.errors = []

        maxlevel = 4
        _tanhsinh(f, a, b, minlevel=0, maxlevel=maxlevel, callback=callback)

        for i in range(maxlevel + 1):
            res = _tanhsinh(f, a, b, minlevel=i, maxlevel=i)
            # Index 1+i: the first callback happens before the first level.
            xp_assert_close(callback.integrals[1+i], res.integral, rtol=1e-15)
            xp_assert_close(callback.errors[1+i], res.error, rtol=1e-15, atol=1e-16)

    def test_special_cases(self, xp):
        # Test edge cases and other special cases
        a, b = xp.asarray(0), xp.asarray(1)

        # Integer limits must still lead to real floating abscissae.
        def f(x):
            assert xp.isdtype(x.dtype, "real floating")
            return x

        res = _tanhsinh(f, a, b)
        assert res.success
        xp_assert_close(res.integral, xp.asarray(0.5))

        # Test levels 0 and 1; error is NaN
        res = _tanhsinh(f, a, b, maxlevel=0)
        assert res.integral > 0
        xp_assert_equal(res.error, xp.asarray(xp.nan))
        res = _tanhsinh(f, a, b, maxlevel=1)
        assert res.integral > 0
        xp_assert_equal(res.error, xp.asarray(xp.nan))

        # Test equal left and right integration limits
        res = _tanhsinh(f, b, b)
        assert res.success
        assert res.maxlevel == -1
        xp_assert_close(res.integral, xp.asarray(0.))

        # Test scalar `args` (not in tuple)
        def f(x, c):
            return x**c

        res = _tanhsinh(f, a, b, args=29)
        xp_assert_close(res.integral, xp.asarray(1/30))

        # Test NaNs
        a = xp.asarray([xp.nan, 0, 0, 0])
        b = xp.asarray([1, xp.nan, 1, 1])
        c = xp.asarray([1, 1, xp.nan, 1])
        res = _tanhsinh(f, a, b, args=(c,))
        xp_assert_close(res.integral, xp.asarray([xp.nan, xp.nan, xp.nan, 0.5]))
        xp_assert_equal(res.error[:3], xp.full((3,), xp.nan))
        xp_assert_equal(res.status, xp.asarray([-3, -3, -3, 0], dtype=xp.int32))
        xp_assert_equal(res.success, xp.asarray([False, False, False, True]))
        xp_assert_equal(res.nfev[:3], xp.full((3,), 1, dtype=xp.int32))

        # Test complex integral followed by real integral
        # Previously, h0 was of the result dtype. If the `dtype` were complex,
        # this could lead to complex cached abscissae/weights. If these get
        # cast to real dtype for a subsequent real integral, we would get a
        # ComplexWarning. Check that this is avoided.
        # Reset the module-level cache so the complex integral below is the
        # first one to populate it.
        _pair_cache.xjc = xp.empty(0)
        _pair_cache.wj = xp.empty(0)
        _pair_cache.indices = [0]
        _pair_cache.h0 = None
        a, b = xp.asarray(0), xp.asarray(1)
        res = _tanhsinh(lambda x: xp.asarray(x*1j), a, b)
        xp_assert_close(res.integral, xp.asarray(0.5*1j))
        res = _tanhsinh(lambda x: x, a, b)
        xp_assert_close(res.integral, xp.asarray(0.5))

        # Test zero-size
        shape = (0, 3)
        res = _tanhsinh(lambda x: x, xp.asarray(0), xp.zeros(shape))
        attrs = ['integral', 'error', 'success', 'status', 'nfev', 'maxlevel']
        for attr in attrs:
            assert res[attr].shape == shape

    @pytest.mark.skip_xp_backends(np_only=True)
    def test_compress_nodes_weights_gh21496(self, xp):
        # See discussion in:
        # https://github.com/scipy/scipy/pull/21496#discussion_r1878681049
        # This would cause "ValueError: attempt to get argmax of an empty sequence"
        # Check that this has been resolved.
        x = np.full(65, 3)
        x[-1] = 1000
        _tanhsinh(np.sin, 1, x)

    def test_gh_22681_finite_error(self, xp):
        # gh-22681 reported that, for this integrand, the error estimate was
        # NaN on some platforms. Check that it is finite (the test exists so
        # that CI would catch the problem) and that convergence is detected.
        c1 = complex(12, -10)
        c2 = complex(12, 39)
        def f(t):
            return xp.sin(c1 * (1 - t) + c2 * t)
        a, b = xp.asarray(0., dtype=xp.float64), xp.asarray(1., dtype=xp.float64)
        ref = _tanhsinh(f, a, b, atol=0, rtol=0, maxlevel=10)
        assert xp.isfinite(ref.error)
        # Previously, tanhsinh would not detect convergence
        res = _tanhsinh(f, a, b, rtol=1e-14)
        assert res.success
        assert res.maxlevel < 5
        xp_assert_close(res.integral, ref.integral, rtol=1e-15)
  627. @make_xp_test_case(nsum)
  628. class TestNSum:
  629. rng = np.random.default_rng(5895448232066142650)
  630. p = rng.uniform(1, 10, size=10).tolist()
  631. def f1(self, k):
  632. # Integers are never passed to `f1`; if they were, we'd get
  633. # integer to negative integer power error
  634. return k**(-2)
  635. f1.ref = np.pi**2/6
  636. f1.a = 1
  637. f1.b = np.inf
  638. f1.args = tuple()
  639. def f2(self, k, p):
  640. return 1 / k**p
  641. f2.ref = special.zeta(p, 1)
  642. f2.a = 1.
  643. f2.b = np.inf
  644. f2.args = (p,)
  645. def f3(self, k, p):
  646. return 1 / k**p
  647. f3.a = 1
  648. f3.b = rng.integers(5, 15, size=(3, 1))
  649. f3.ref = _gen_harmonic(f3.b, p)
  650. f3.args = (p,)
  def test_input_validation(self, xp):
      # Bad argument types/values raise `ValueError`; invalid limits or step
      # sizes are instead reported element-wise through `status`.
      f = self.f1
      a, b = xp.asarray(f.a), xp.asarray(f.b)

      message = '`f` must be callable.'
      with pytest.raises(ValueError, match=message):
          nsum(42, a, b)

      message = '...must be True or False.'
      with pytest.raises(ValueError, match=message):
          nsum(f, a, b, log=2)

      message = '...must be real numbers.'
      with pytest.raises(ValueError, match=message):
          nsum(f, xp.asarray(1+1j), b)
      with pytest.raises(ValueError, match=message):
          nsum(f, a, xp.asarray(1+1j))
      with pytest.raises(ValueError, match=message):
          nsum(f, a, b, step=xp.asarray(1+1j))
      with pytest.raises(ValueError, match=message):
          nsum(f, a, b, tolerances=dict(atol='ekki'))
      with pytest.raises(ValueError, match=message):
          nsum(f, a, b, tolerances=dict(rtol=pytest))

      with np.errstate(all='ignore'):
          # A NaN lower limit gives status -1 after a single evaluation; the
          # second element (a = inf, b = 1) succeeds with sum equal to error.
          res = nsum(f, xp.asarray([np.nan, np.inf]), xp.asarray(1.))
          assert (res.status[0] == -1) and not res.success[0]
          assert xp.isnan(res.sum[0]) and xp.isnan(res.error[0])
          assert (res.status[1] == 0) and res.success[1]
          assert res.sum[1] == res.error[1]
          assert xp.all(res.nfev[0] == 1)

          # Same checks with a NaN upper limit.
          res = nsum(f, xp.asarray(10.), xp.asarray([np.nan, 1]))
          assert (res.status[0] == -1) and not res.success[0]
          assert xp.isnan(res.sum[0]) and xp.isnan(res.error[0])
          assert (res.status[1] == 0) and res.success[1]
          assert res.sum[1] == res.error[1]
          assert xp.all(res.nfev[0] == 1)

          # Non-finite or non-positive step sizes are all invalid.
          res = nsum(f, xp.asarray(1.), xp.asarray(10.),
                     step=xp.asarray([xp.nan, -xp.inf, xp.inf, -1, 0]))
          # `&` binds more tightly than `==`, so `(res.nfev == 1)` must be
          # parenthesized. Without the parentheses the whole conjunction was
          # compared against 1, and e.g. nfev == 3 would have passed.
          assert xp.all((res.status == -1) & xp.isnan(res.sum)
                        & xp.isnan(res.error) & ~res.success
                        & (res.nfev == 1))

      message = '...must be non-negative and finite.'
      with pytest.raises(ValueError, match=message):
          nsum(f, a, b, tolerances=dict(rtol=-1))
      with pytest.raises(ValueError, match=message):
          nsum(f, a, b, tolerances=dict(atol=np.inf))

      message = '...may not be positive infinity.'
      with pytest.raises(ValueError, match=message):
          nsum(f, a, b, tolerances=dict(rtol=np.inf), log=True)
      with pytest.raises(ValueError, match=message):
          nsum(f, a, b, tolerances=dict(atol=np.inf), log=True)

      message = '...must be a non-negative integer.'
      with pytest.raises(ValueError, match=message):
          nsum(f, a, b, maxterms=3.5)
      with pytest.raises(ValueError, match=message):
          nsum(f, a, b, maxterms=-2)
  @pytest.mark.parametrize('f_number', range(1, 4))
  def test_basic(self, f_number, xp):
      # Sum each fixture `f1`..`f3` both directly and in log space, and
      # compare with the reference value attached to the fixture.
      func = getattr(self, f"f{f_number}")
      default_dtype = xp.asarray(1.).dtype
      lower = xp.asarray(func.a)
      upper = xp.asarray(func.b)
      extra = tuple(xp.asarray(arg) for arg in func.args)
      expected = xp.asarray(func.ref, dtype=default_dtype)

      res = nsum(func, lower, upper, args=extra)
      xp_assert_close(res.sum, expected)
      xp_assert_equal(res.status, xp.zeros(expected.shape, dtype=xp.int32))
      xp_assert_equal(res.success, xp.ones(expected.shape, dtype=xp.bool))

      def log_func(*fargs):
          return xp.log(func(*fargs))

      with np.errstate(divide='ignore'):
          logres = nsum(log_func, lower, upper, log=True, args=extra)
      xp_assert_close(xp.exp(logres.sum), res.sum)
      xp_assert_close(xp.exp(logres.error), res.error, atol=1e-15)
      xp_assert_equal(logres.status, res.status)
      xp_assert_equal(logres.success, res.success)
  @pytest.mark.parametrize('maxterms', [0, 1, 10, 20, 100])
  def test_integral(self, maxterms, xp):
      # Test the precise behavior of the integral approximation: the first
      # `maxterms` terms are summed directly and the remainder is bracketed
      # by the integral of f plus either the last or the first remaining term.
      f = self.f1

      def logf(x):
          return -2*xp.log(x)

      def F(x):  # antiderivative of f(x) = x**-2
          return -1 / x

      a = xp.asarray([1, 5], dtype=xp.float64)[:, xp.newaxis]
      b_original = xp.asarray([20, 100, xp.inf], dtype=xp.float64)[:, xp.newaxis, xp.newaxis]
      step = xp.asarray([0.5, 1, 2], dtype=xp.float64).reshape((-1, 1, 1, 1))

      # Last term actually reachable from `a` in whole steps, and the first
      # term that is not summed directly.
      n_steps = xp.floor((b_original - a)/step)
      b = a + n_steps*step
      k = a + maxterms*step

      partial = xp.sum(f(a + xp.arange(maxterms)*step), axis=-1, keepdims=True)
      remainder = (F(b) - F(k))/step  # integral approximation of remainder
      lower_bound = partial + remainder + f(b)
      upper_bound = partial + remainder + f(k)
      ref_sum = (lower_bound + upper_bound)/2  # nsum uses the midpoint
      ref_err = (upper_bound - lower_bound)/2  # assumes perfect quadrature

      # Where the whole sum fits within `maxterms`, it is computed directly.
      a, b, step = xp.broadcast_arrays(a, b, step)
      for idx in np.ndindex(a.shape):
          a_val, b_val, step_val = float(a[idx]), float(b[idx]), float(step[idx])
          if not ((b_val - a_val)/step_val + 1 <= maxterms):
              continue
          terms = f(xp.arange(a_val, b_val + step_val, step_val, dtype=xp.float64))
          exact = xp.sum(terms)
          ref_sum[idx] = exact
          ref_err[idx] = exact * xp.finfo(exact.dtype).eps

      rtol = 1e-12
      res = nsum(f, a, b_original, step=step, maxterms=maxterms,
                 tolerances=dict(rtol=rtol))
      xp_assert_close(res.sum, ref_sum, rtol=10*rtol)
      xp_assert_close(res.error, ref_err, rtol=100*rtol)

      direct_only = ((b_original - a)/step + 1 <= maxterms)
      xp_assert_close(res.sum[direct_only], ref_sum[direct_only], rtol=1e-15)
      xp_assert_close(res.error[direct_only], ref_err[direct_only], rtol=1e-15)

      logres = nsum(logf, a, b_original, step=step, log=True,
                    tolerances=dict(rtol=math.log(rtol)), maxterms=maxterms)
      xp_assert_close(xp.exp(logres.sum), res.sum)
      xp_assert_close(xp.exp(logres.error), res.error)
  @pytest.mark.parametrize('shape', [tuple(), (12,), (3, 4), (3, 2, 2)])
  def test_vectorization(self, shape, xp):
      # Test for correct functionality, output shapes, and dtypes for various
      # input shapes: one vectorized call must agree with element-wise calls.
      rng = np.random.default_rng(82456839535679456794)
      a = rng.integers(1, 10, size=shape)
      # when the sum can be computed directly or `maxterms` is large enough
      # to meet `atol`, there are slight differences (for good reason)
      # between vectorized call and looping.
      b = np.inf
      p = rng.random(shape) + 1
      n = math.prod(shape)

      def f(x, p):
          # Count evaluations per element. `xp_size` is used instead of
          # `x.size` because on some array types (e.g. torch tensors) `size`
          # is a method, which made `x.size == n` silently always False.
          f.feval += 1 if (xp_size(x) == n or x.ndim <= 1) else x.shape[-1]
          return 1 / x ** p
      f.feval = 0

      @np.vectorize
      def nsum_single(a, b, p, maxterms):
          return nsum(lambda x: 1 / x**p, a, b, maxterms=maxterms)

      res = nsum(f, xp.asarray(a), xp.asarray(b), maxterms=1000,
                 args=(xp.asarray(p),))
      refs = nsum_single(a, b, p, maxterms=1000).ravel()

      attrs = ['sum', 'error', 'success', 'status', 'nfev']
      for attr in attrs:
          ref_attr = [xp.asarray(getattr(ref, attr)) for ref in refs]
          res_attr = getattr(res, attr)
          xp_assert_close(xp_ravel(res_attr), xp.asarray(ref_attr), rtol=1e-15)
          assert res_attr.shape == shape

      assert xp.isdtype(res.success.dtype, 'bool')
      assert xp.isdtype(res.status.dtype, 'integral')
      assert xp.isdtype(res.nfev.dtype, 'integral')
      if is_numpy(xp):  # other libraries might have different number
          assert int(xp.max(res.nfev)) == f.feval
  def test_status(self, xp):
      # Each element triggers a different termination status; only the
      # successful element's sum is compared with the zeta reference.
      f = self.f2
      exponents = [2, 2, 0.9, 1.1, 2, 2]
      a = xp.asarray([0, 0, 1, 1, 1, np.nan], dtype=xp.float64)
      b = xp.asarray([10, np.inf, np.inf, np.inf, np.inf, np.inf], dtype=xp.float64)
      ref = special.zeta(exponents, 1)
      p = xp.asarray(exponents, dtype=xp.float64)

      with np.errstate(divide='ignore'):  # a == 0 divides by zero on purpose
          res = nsum(f, a, b, args=(p,))

      expected_success = xp.asarray([False, False, False, False, True, False])
      expected_status = xp.asarray([-3, -3, -2, -4, 0, -1], dtype=xp.int32)
      xp_assert_equal(res.success, expected_success)
      xp_assert_equal(res.status, expected_status)
      ok = res.success
      xp_assert_close(res.sum[ok], xp.asarray(ref)[ok])
  def test_nfev(self, xp):
      # `res.nfev` must equal the number of points at which `f` was evaluated,
      # for both a finite sum and an infinite one with a loose tolerance.
      def f(x):
          f.nfev += xp_size(x)
          return 1 / x**2

      cases = [(xp.asarray(10), {}),
               (xp.asarray(xp.inf), dict(tolerances=dict(atol=1e-6)))]
      for upper, kwargs in cases:
          f.nfev = 0
          res = nsum(f, xp.asarray(1), upper, **kwargs)
          assert res.nfev == f.nfev
  def test_inclusive(self, xp):
      # Regression test for an off-by-one edge case when `_direct` was called
      # with `inclusive=True`: a loose-tolerance sum must lie within its own
      # error estimate of the default-tolerance result.
      def summand(k):
          return 1 / k ** 2

      a = xp.asarray([1, 4])
      b = xp.asarray(xp.inf)
      res = nsum(summand, a, b, maxterms=500, tolerances=dict(atol=0.1))
      ref = nsum(summand, a, b)

      lower = ref.sum - res.error
      upper = ref.sum + res.error
      assert xp.all(res.sum > lower)
      assert xp.all(res.sum < upper)
  @pytest.mark.parametrize('log', [True, False])
  def test_infinite_bounds(self, log, xp):
      # Sum the discrete Laplace PMF tanh(c/2) * exp(-c*|x|) over [1, inf),
      # (-inf, -1] and (-inf, inf), in linear or log space.
      a = xp.asarray([1, -np.inf, -np.inf])
      b = xp.asarray([np.inf, -1, np.inf])
      c = xp.asarray([1, 2, 3])

      def laplace(x, a):
          if log:
              return xp.log(xp.tanh(a / 2)) - a*xp.abs(x)
          return xp.tanh(a/2) * xp.exp(-a*xp.abs(x))

      res = nsum(laplace, a, b, args=(c,), log=log)
      ref = xp.asarray([stats.dlaplace.sf(0, 1), stats.dlaplace.sf(0, 2), 1])
      if log:
          ref = xp.log(ref)
      atol = 0
      if log:
          atol = 1e-10 if a.dtype == xp.float64 else 1e-5
      xp_assert_close(res.sum, xp.asarray(ref, dtype=a.dtype), atol=atol)

      # Make sure the sign of `x` passed into the summand is correct: with
      # c = -1 on (-inf, -1], c*x runs over 1, 2, ... exactly as on [1, inf).
      def cubic(x, c):
          return -3*xp.log(c*x) if log else 1 / (c*x)**3

      a = xp.asarray([1, -np.inf])
      b = xp.asarray([np.inf, -1])
      sign = xp.asarray([1, -1])
      res = nsum(cubic, a, b, args=(sign,), log=log)
      ref = np.log(special.zeta(3)) if log else special.zeta(3)
      xp_assert_close(res.sum, xp.full(a.shape, ref, dtype=a.dtype))
  def test_decreasing_check(self, xp):
      # Start the sum on an uphill slope (x = -25 rising towards x = 0).
      # Without the decreasing check the tiny leading terms would make the
      # integral approximation look acceptable, but because the function is
      # not decreasing there, the last term of the partial sum does not bound
      # the error; the error would be ~1e-4 and this test would fail.
      def gaussian(x):
          return xp.exp(-x ** 2)

      start = xp.asarray(-25, dtype=xp.float64)
      stop = xp.asarray(np.inf, dtype=xp.float64)
      res = nsum(gaussian, start, stop)

      # Reference computed with mpmath:
      # from mpmath import mp
      # mp.dps = 50
      # def fmp(x): return mp.exp(-x**2)
      # ref = mp.nsum(fmp, (-25, 0)) + mp.nsum(fmp, (1, mp.inf))
      expected = xp.asarray(1.772637204826652, dtype=xp.float64)
      xp_assert_close(res.sum, expected, rtol=1e-15)
  def test_special_case(self, xp):
      # Edge cases: equal limits, scalar `args`, zero-size input, `maxterms`
      # of 0 and 1, and NaN inputs.

      # Equal lower and upper limits: the sum is the single term f(a).
      f = self.f1
      a = b = xp.asarray(2)
      res = nsum(f, a, b)
      xp_assert_equal(res.sum, xp.asarray(f(2)))

      # A scalar `args` (not wrapped in a tuple) is accepted. With p = 2,
      # `f2` coincides with `f1`, so `f1.ref` is the expected value.
      res = nsum(self.f2, xp.asarray(1), xp.asarray(np.inf), args=xp.asarray(2))
      xp_assert_close(res.sum, xp.asarray(self.f1.ref))

      # Zero-size input: outputs take the broadcast shape of arbitrary
      # broadcastable inputs (Hypothesis would be overkill here).
      a = xp.empty((3, 1, 1))
      b = xp.empty((0, 1))
      p = xp.empty(4)
      shape = np.broadcast_shapes(a.shape, b.shape, p.shape)
      res = nsum(self.f2, a, b, args=(p,))
      for out in (res.sum, res.status, res.nfev):
          assert out.shape == shape

      # maxterms = 0 and 1 with a first term of 1/0.
      def reciprocal(x):
          with np.errstate(divide='ignore'):
              return 1 / x

      res = nsum(reciprocal, xp.asarray(0), xp.asarray(10), maxterms=0)
      assert xp.isinf(res.sum)
      assert xp.isinf(res.error)
      assert res.status == -2

      res = nsum(reciprocal, xp.asarray(0), xp.asarray(10), maxterms=1)
      assert xp.isnan(res.sum)
      assert xp.isnan(res.error)
      assert res.status == -3

      # NaNs: both the direct and the integral methods should be skipped.
      a = xp.asarray([xp.nan, 1, 1, 1])
      b = xp.asarray([xp.inf, xp.nan, xp.inf, xp.inf])
      p = xp.asarray([2, 2, xp.nan, 2])
      res = nsum(self.f2, a, b, args=(p,))
      xp_assert_close(res.sum, xp.asarray([xp.nan, xp.nan, xp.nan, self.f1.ref]))
      xp_assert_close(res.error[:3], xp.full((3,), xp.nan))
      xp_assert_equal(res.status, xp.asarray([-1, -1, -3, 0], dtype=xp.int32))
      xp_assert_equal(res.success, xp.asarray([False, False, False, True]))
      # Only the first two are checked: ideally res.nfev[2] would be 1 as
      # well, but `tanhsinh` has some function evals.
      xp_assert_equal(res.nfev[:2], xp.full((2,), 1, dtype=xp.int32))
  @pytest.mark.parametrize('dtype', ['float32', 'float64'])
  def test_dtype(self, dtype, xp):
      # The summand receives, and the results keep, the dtype of the limits.
      dtype = getattr(xp, dtype)
      two = xp.asarray(2, dtype=dtype)

      def f(k):
          assert k.dtype == dtype
          return 1 / k ** two

      lower = xp.asarray(1, dtype=dtype)
      upper = xp.asarray([10, xp.inf], dtype=dtype)
      res = nsum(f, lower, upper)
      assert res.sum.dtype == dtype
      assert res.error.dtype == dtype

      rtol = 1e-12 if dtype == xp.float64 else 1e-6
      expected = [_gen_harmonic(10, 2), special.zeta(2, 1)]
      xp_assert_close(res.sum, xp.asarray(expected, dtype=dtype), rtol=rtol)
  @pytest.mark.parametrize('case', [(10, 100), (100, 10)])
  def test_nondivisible_interval(self, case, xp):
      # When (b - a)/step is not exactly integral, only terms up to
      # a + floor((b - a)/step)*step may be included. `b` is nudged by a few
      # ulps around a + n*step so that floor(...) takes two distinct values.
      n, maxterms = case

      def f(k):
          return 1 / k ** 2

      a = np.e
      step = 1 / 3
      b0 = a + n * step
      b = b0 + np.arange(-2, 3) * np.spacing(b0)
      ns = np.floor((b - a) / step)
      assert len(set(ns)) == 2

      a = xp.asarray(a, dtype=xp.float64)
      b = xp.asarray(b, dtype=xp.float64)
      step = xp.asarray(step, dtype=xp.float64)
      ns = xp.asarray(ns, dtype=xp.float64)
      res = nsum(f, a, b, step=step, maxterms=maxterms)
      # The sum grows exactly where the number of whole steps does ...
      xp_assert_equal(xp.diff(ns) > 0, xp.diff(res.sum) > 0)
      # ... and the extra term is f(b0) = f(a + n*step).
      xp_assert_close(res.sum[-1], res.sum[0] + f(b0))
  @pytest.mark.skip_xp_backends(np_only=True, reason='Needs beta function.')
  def test_logser_kurtosis_gh20648(self, xp):
      # Some functions return NaN at infinity rather than 0 like they should;
      # check that `nsum` accounts for this. The summand pmf(x) * x**4 of the
      # Yule-Simon distribution (shape 5) is such a function, and its sum is
      # compared with `stats.yulesimon.moment(4, 5)`.
      expected = stats.yulesimon.moment(4, 5)

      def weighted_pmf(x):
          return stats.yulesimon._pmf(x, 5) * x**4

      with np.errstate(invalid='ignore'):
          assert np.isnan(weighted_pmf(np.inf))

      res = nsum(weighted_pmf, 1, np.inf)
      assert_allclose(res.sum, expected)