test__util.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638
  1. from multiprocessing import Pool
  2. from multiprocessing.pool import Pool as PWL
  3. import re
  4. import math
  5. import functools
  6. from fractions import Fraction
  7. import numpy as np
  8. from numpy.testing import assert_equal, assert_
  9. import pytest
  10. from pytest import raises as assert_raises
  11. from scipy.conftest import skip_xp_invalid_arg
  12. from scipy._lib._array_api import xp_assert_equal
  13. from scipy._lib._util import (_aligned_zeros, check_random_state, MapWrapper,
  14. getfullargspec_no_self, FullArgSpec,
  15. rng_integers, _validate_int, _rename_parameter,
  16. _contains_nan, _rng_html_rewrite, _workers_wrapper)
  17. import scipy._lib.array_api_extra as xpx
  18. from scipy._lib.array_api_extra.testing import lazy_xp_function
  19. from scipy import cluster, interpolate, linalg, optimize, sparse, spatial, stats
  20. lazy_xp_function(_contains_nan)
  21. @pytest.mark.slow
  22. def test__aligned_zeros():
  23. niter = 10
  24. def check(shape, dtype, order, align):
  25. err_msg = repr((shape, dtype, order, align))
  26. x = _aligned_zeros(shape, dtype, order, align=align)
  27. if align is None:
  28. align = np.dtype(dtype).alignment
  29. assert_equal(x.__array_interface__['data'][0] % align, 0)
  30. if hasattr(shape, '__len__'):
  31. assert_equal(x.shape, shape, err_msg)
  32. else:
  33. assert_equal(x.shape, (shape,), err_msg)
  34. assert_equal(x.dtype, dtype)
  35. if order == "C":
  36. assert_(x.flags.c_contiguous, err_msg)
  37. elif order == "F":
  38. if x.size > 0:
  39. # Size-0 arrays get invalid flags on NumPy 1.5
  40. assert_(x.flags.f_contiguous, err_msg)
  41. elif order is None:
  42. assert_(x.flags.c_contiguous, err_msg)
  43. else:
  44. raise ValueError()
  45. # try various alignments
  46. for align in [1, 2, 3, 4, 8, 16, 32, 64, None]:
  47. for n in [0, 1, 3, 11]:
  48. for order in ["C", "F", None]:
  49. for dtype in [np.uint8, np.float64]:
  50. for shape in [n, (1, 2, 3, n)]:
  51. for j in range(niter):
  52. check(shape, dtype, order, align)
  53. def test_check_random_state():
  54. # If seed is None, return the RandomState singleton used by np.random.
  55. # If seed is an int, return a new RandomState instance seeded with seed.
  56. # If seed is already a RandomState instance, return it.
  57. # Otherwise raise ValueError.
  58. rsi = check_random_state(1)
  59. assert_equal(type(rsi), np.random.RandomState)
  60. rsi = check_random_state(rsi)
  61. assert_equal(type(rsi), np.random.RandomState)
  62. rsi = check_random_state(None)
  63. assert_equal(type(rsi), np.random.RandomState)
  64. assert_raises(ValueError, check_random_state, 'a')
  65. rg = np.random.Generator(np.random.PCG64())
  66. rsi = check_random_state(rg)
  67. assert_equal(type(rsi), np.random.Generator)
  68. def test_getfullargspec_no_self():
  69. p = MapWrapper(1)
  70. argspec = getfullargspec_no_self(p.__init__)
  71. assert_equal(argspec, FullArgSpec(['pool'], None, None, (1,), [],
  72. None, {}))
  73. argspec = getfullargspec_no_self(p.__call__)
  74. assert_equal(argspec, FullArgSpec(['func', 'iterable'], None, None, None,
  75. [], None, {}))
  76. class _rv_generic:
  77. def _rvs(self, a, b=2, c=3, *args, size=None, **kwargs):
  78. return None
  79. rv_obj = _rv_generic()
  80. argspec = getfullargspec_no_self(rv_obj._rvs)
  81. assert_equal(argspec, FullArgSpec(['a', 'b', 'c'], 'args', 'kwargs',
  82. (2, 3), ['size'], {'size': None}, {}))
  83. def test_mapwrapper_serial():
  84. in_arg = np.arange(10.)
  85. out_arg = np.sin(in_arg)
  86. p = MapWrapper(1)
  87. assert_(p._mapfunc is map)
  88. assert_(p.pool is None)
  89. assert_(p._own_pool is False)
  90. out = list(p(np.sin, in_arg))
  91. assert_equal(out, out_arg)
  92. with assert_raises(RuntimeError):
  93. p = MapWrapper(0)
  94. def test_pool():
  95. with Pool(2) as p:
  96. p.map(math.sin, [1, 2, 3, 4])
  97. def test_mapwrapper_parallel():
  98. in_arg = np.arange(10.)
  99. out_arg = np.sin(in_arg)
  100. with MapWrapper(2) as p:
  101. out = p(np.sin, in_arg)
  102. assert_equal(list(out), out_arg)
  103. assert_(p._own_pool is True)
  104. assert_(isinstance(p.pool, PWL))
  105. assert_(p._mapfunc is not None)
  106. # the context manager should've closed the internal pool
  107. # check that it has by asking it to calculate again.
  108. with assert_raises(Exception) as excinfo:
  109. p(np.sin, in_arg)
  110. assert_(excinfo.type is ValueError)
  111. # can also set a PoolWrapper up with a map-like callable instance
  112. with Pool(2) as p:
  113. q = MapWrapper(p.map)
  114. assert_(q._own_pool is False)
  115. q.close()
  116. # closing the PoolWrapper shouldn't close the internal pool
  117. # because it didn't create it
  118. out = p.map(np.sin, in_arg)
  119. assert_equal(list(out), out_arg)
  120. @_workers_wrapper
  121. def user_of_workers(x, b=1, workers=None):
  122. assert workers is not None
  123. assert isinstance(workers, MapWrapper)
  124. return np.array(list(workers(np.sin, x * b)))
  125. def test__workers_wrapper():
  126. arr = np.linspace(0, np.pi)
  127. req = np.sin(arr * 2.0)
  128. with Pool(2) as p:
  129. v = user_of_workers(arr, workers=p.map, b=2)
  130. assert_equal(v, req)
  131. v = user_of_workers(arr, workers=None, b=2)
  132. assert_equal(v, req)
  133. v = user_of_workers(arr, workers=2, b=2)
  134. assert_equal(v, req)
  135. # assess if decorator works with partial functions
  136. part_f = functools.partial(user_of_workers, b=2)
  137. assert_equal(part_f(arr), req)
  138. with Pool(2) as p:
  139. part_f = functools.partial(user_of_workers, b=2, workers=p.map)
  140. assert_equal(part_f(arr), req)
  141. def test_rng_integers():
  142. rng = np.random.RandomState()
  143. # test that numbers are inclusive of high point
  144. arr = rng_integers(rng, low=2, high=5, size=100, endpoint=True)
  145. assert np.max(arr) == 5
  146. assert np.min(arr) == 2
  147. assert arr.shape == (100, )
  148. # test that numbers are inclusive of high point
  149. arr = rng_integers(rng, low=5, size=100, endpoint=True)
  150. assert np.max(arr) == 5
  151. assert np.min(arr) == 0
  152. assert arr.shape == (100, )
  153. # test that numbers are exclusive of high point
  154. arr = rng_integers(rng, low=2, high=5, size=100, endpoint=False)
  155. assert np.max(arr) == 4
  156. assert np.min(arr) == 2
  157. assert arr.shape == (100, )
  158. # test that numbers are exclusive of high point
  159. arr = rng_integers(rng, low=5, size=100, endpoint=False)
  160. assert np.max(arr) == 4
  161. assert np.min(arr) == 0
  162. assert arr.shape == (100, )
  163. # now try with np.random.Generator
  164. try:
  165. rng = np.random.default_rng()
  166. except AttributeError:
  167. return
  168. # test that numbers are inclusive of high point
  169. arr = rng_integers(rng, low=2, high=5, size=100, endpoint=True)
  170. assert np.max(arr) == 5
  171. assert np.min(arr) == 2
  172. assert arr.shape == (100, )
  173. # test that numbers are inclusive of high point
  174. arr = rng_integers(rng, low=5, size=100, endpoint=True)
  175. assert np.max(arr) == 5
  176. assert np.min(arr) == 0
  177. assert arr.shape == (100, )
  178. # test that numbers are exclusive of high point
  179. arr = rng_integers(rng, low=2, high=5, size=100, endpoint=False)
  180. assert np.max(arr) == 4
  181. assert np.min(arr) == 2
  182. assert arr.shape == (100, )
  183. # test that numbers are exclusive of high point
  184. arr = rng_integers(rng, low=5, size=100, endpoint=False)
  185. assert np.max(arr) == 4
  186. assert np.min(arr) == 0
  187. assert arr.shape == (100, )
  188. class TestValidateInt:
  189. @pytest.mark.parametrize('n', [4, np.uint8(4), np.int16(4), np.array(4)])
  190. def test_validate_int(self, n):
  191. n = _validate_int(n, 'n')
  192. assert n == 4
  193. @pytest.mark.parametrize('n', [4.0, np.array([4]), Fraction(4, 1)])
  194. def test_validate_int_bad(self, n):
  195. with pytest.raises(TypeError, match='n must be an integer'):
  196. _validate_int(n, 'n')
  197. def test_validate_int_below_min(self):
  198. with pytest.raises(ValueError, match='n must be an integer not '
  199. 'less than 0'):
  200. _validate_int(-1, 'n', 0)
  201. class TestRenameParameter:
  202. # check that wrapper `_rename_parameter` for backward-compatible
  203. # keyword renaming works correctly
  204. # Example method/function that still accepts keyword `old`
  205. @_rename_parameter("old", "new")
  206. def old_keyword_still_accepted(self, new):
  207. return new
  208. # Example method/function for which keyword `old` is deprecated
  209. @_rename_parameter("old", "new", dep_version="1.9.0")
  210. def old_keyword_deprecated(self, new):
  211. return new
  212. def test_old_keyword_still_accepted(self):
  213. # positional argument and both keyword work identically
  214. res1 = self.old_keyword_still_accepted(10)
  215. res2 = self.old_keyword_still_accepted(new=10)
  216. res3 = self.old_keyword_still_accepted(old=10)
  217. assert res1 == res2 == res3 == 10
  218. # unexpected keyword raises an error
  219. message = re.escape("old_keyword_still_accepted() got an unexpected")
  220. with pytest.raises(TypeError, match=message):
  221. self.old_keyword_still_accepted(unexpected=10)
  222. # multiple values for the same parameter raises an error
  223. message = re.escape("old_keyword_still_accepted() got multiple")
  224. with pytest.raises(TypeError, match=message):
  225. self.old_keyword_still_accepted(10, new=10)
  226. with pytest.raises(TypeError, match=message):
  227. self.old_keyword_still_accepted(10, old=10)
  228. with pytest.raises(TypeError, match=message):
  229. self.old_keyword_still_accepted(new=10, old=10)
  230. @pytest.fixture
  231. def kwarg_lock(self):
  232. from threading import Lock
  233. return Lock()
  234. def test_old_keyword_deprecated(self, kwarg_lock):
  235. # positional argument and both keyword work identically,
  236. # but use of old keyword results in DeprecationWarning
  237. dep_msg = "Use of keyword argument `old` is deprecated"
  238. res1 = self.old_keyword_deprecated(10)
  239. res2 = self.old_keyword_deprecated(new=10)
  240. # pytest warning filter is not thread-safe, enforce serialization
  241. with kwarg_lock:
  242. with pytest.warns(DeprecationWarning, match=dep_msg):
  243. res3 = self.old_keyword_deprecated(old=10)
  244. assert res1 == res2 == res3 == 10
  245. # unexpected keyword raises an error
  246. message = re.escape("old_keyword_deprecated() got an unexpected")
  247. with pytest.raises(TypeError, match=message):
  248. self.old_keyword_deprecated(unexpected=10)
  249. # multiple values for the same parameter raises an error and,
  250. # if old keyword is used, results in DeprecationWarning
  251. message = re.escape("old_keyword_deprecated() got multiple")
  252. with pytest.raises(TypeError, match=message):
  253. self.old_keyword_deprecated(10, new=10)
  254. with kwarg_lock:
  255. with pytest.raises(TypeError, match=message), \
  256. pytest.warns(DeprecationWarning, match=dep_msg):
  257. # breakpoint()
  258. self.old_keyword_deprecated(10, old=10)
  259. with kwarg_lock:
  260. with pytest.raises(TypeError, match=message), \
  261. pytest.warns(DeprecationWarning, match=dep_msg):
  262. self.old_keyword_deprecated(new=10, old=10)
  263. class TestContainsNaN:
  264. def test_policy(self):
  265. data = np.array([1, 2, 3, np.nan])
  266. assert _contains_nan(data) # default policy is "propagate"
  267. assert _contains_nan(data, nan_policy="propagate")
  268. assert _contains_nan(data, nan_policy="omit")
  269. assert not _contains_nan(data[:3])
  270. assert not _contains_nan(data[:3], nan_policy="propagate")
  271. assert not _contains_nan(data[:3], nan_policy="omit")
  272. with pytest.raises(ValueError, match="The input contains nan values"):
  273. _contains_nan(data, nan_policy="raise")
  274. assert not _contains_nan(data[:3], nan_policy="raise")
  275. with pytest.raises(ValueError, match="nan_policy must be one of"):
  276. _contains_nan(data, nan_policy="nan")
  277. def test_contains_nan(self):
  278. # Special case: empty array
  279. assert not _contains_nan(np.array([], dtype=float))
  280. # Integer arrays cannot contain NaN
  281. assert not _contains_nan(np.array([1, 2, 3]))
  282. assert not _contains_nan(np.array([[1, 2], [3, 4]]))
  283. assert not _contains_nan(np.array([1., 2., 3.]))
  284. assert not _contains_nan(np.array([1., 2.j, 3.]))
  285. assert _contains_nan(np.array([1., 2.j, np.nan]))
  286. assert _contains_nan(np.array([1., 2., np.nan]))
  287. assert _contains_nan(np.array([np.nan, 2., np.nan]))
  288. assert not _contains_nan(np.array([[1., 2.], [3., 4.]]))
  289. assert _contains_nan(np.array([[1., 2.], [3., np.nan]]))
  290. @skip_xp_invalid_arg
  291. def test_contains_nan_with_strings(self):
  292. data1 = np.array([1, 2, "3", np.nan]) # converted to string "nan"
  293. assert not _contains_nan(data1)
  294. data2 = np.array([1, 2, "3", np.nan], dtype='object')
  295. assert _contains_nan(data2)
  296. data3 = np.array([["1", 2], [3, np.nan]]) # converted to string "nan"
  297. assert not _contains_nan(data3)
  298. data4 = np.array([["1", 2], [3, np.nan]], dtype='object')
  299. assert _contains_nan(data4)
  300. @pytest.mark.skip_xp_backends(eager_only=True,
  301. reason="lazy backends tested separately")
  302. @pytest.mark.parametrize("nan_policy", ['propagate', 'omit', 'raise'])
  303. def test_array_api(self, xp, nan_policy):
  304. rng = np.random.default_rng(932347235892482)
  305. x0 = rng.random(size=(2, 3, 4))
  306. x = xp.asarray(x0)
  307. assert not _contains_nan(x, nan_policy)
  308. x = xpx.at(x)[1, 2, 1].set(xp.nan)
  309. if nan_policy == 'raise':
  310. with pytest.raises(ValueError, match="The input contains nan values"):
  311. _contains_nan(x, nan_policy)
  312. elif nan_policy == 'omit':
  313. assert _contains_nan(x, nan_policy, xp_omit_okay=True)
  314. elif nan_policy == 'propagate':
  315. assert _contains_nan(x, nan_policy)
  316. @pytest.mark.skip_xp_backends("numpy", reason="lazy backends only")
  317. @pytest.mark.skip_xp_backends("cupy", reason="lazy backends only")
  318. @pytest.mark.skip_xp_backends("array_api_strict", reason="lazy backends only")
  319. @pytest.mark.skip_xp_backends("torch", reason="lazy backends only")
  320. def test_array_api_lazy(self, xp):
  321. rng = np.random.default_rng(932347235892482)
  322. x0 = rng.random(size=(2, 3, 4))
  323. x = xp.asarray(x0)
  324. xp_assert_equal(_contains_nan(x), xp.asarray(False))
  325. xp_assert_equal(_contains_nan(x, "propagate"), xp.asarray(False))
  326. xp_assert_equal(_contains_nan(x, "omit", xp_omit_okay=True), xp.asarray(False))
  327. # Lazy arrays don't support "omit" and "raise" policies
  328. match = "not supported for lazy arrays"
  329. with pytest.raises(TypeError, match=match):
  330. _contains_nan(x, "omit")
  331. with pytest.raises(TypeError, match=match):
  332. _contains_nan(x, "raise")
  333. x = xpx.at(x)[1, 2, 1].set(np.nan)
  334. xp_assert_equal(_contains_nan(x), xp.asarray(True))
  335. xp_assert_equal(_contains_nan(x, "propagate"), xp.asarray(True))
  336. xp_assert_equal(_contains_nan(x, "omit", xp_omit_okay=True), xp.asarray(True))
  337. with pytest.raises(TypeError, match=match):
  338. _contains_nan(x, "omit")
  339. with pytest.raises(TypeError, match=match):
  340. _contains_nan(x, "raise")
  341. def test__rng_html_rewrite():
  342. def mock_str():
  343. lines = [
  344. 'np.random.default_rng(8989843)',
  345. 'np.random.default_rng(seed)',
  346. 'np.random.default_rng(0x9a71b21474694f919882289dc1559ca)',
  347. ' bob ',
  348. ]
  349. return lines
  350. res = _rng_html_rewrite(mock_str)()
  351. ref = [
  352. 'np.random.default_rng()',
  353. 'np.random.default_rng(seed)',
  354. 'np.random.default_rng()',
  355. ' bob ',
  356. ]
  357. assert res == ref
  358. class TestTransitionToRNG:
  359. def kmeans(self, **kwargs):
  360. rng = np.random.default_rng(3458934594269824562)
  361. return cluster.vq.kmeans2(rng.random(size=(20, 3)), 3, **kwargs)
  362. def kmeans2(self, **kwargs):
  363. rng = np.random.default_rng(3458934594269824562)
  364. return cluster.vq.kmeans2(rng.random(size=(20, 3)), 3, **kwargs)
  365. def barycentric(self, **kwargs):
  366. rng = np.random.default_rng(3458934594269824562)
  367. x1, x2, y1 = rng.random((3, 10))
  368. f = interpolate.BarycentricInterpolator(x1, y1, **kwargs)
  369. return f(x2)
  370. def clarkson_woodruff_transform(self, **kwargs):
  371. rng = np.random.default_rng(3458934594269824562)
  372. return linalg.clarkson_woodruff_transform(rng.random((10, 10)), 3, **kwargs)
  373. def basinhopping(self, **kwargs):
  374. rng = np.random.default_rng(3458934594269824562)
  375. return optimize.basinhopping(optimize.rosen, rng.random(3), **kwargs).x
  376. def opt(self, fun, **kwargs):
  377. rng = np.random.default_rng(3458934594269824562)
  378. bounds = optimize.Bounds(-rng.random(3) * 10, rng.random(3) * 10)
  379. return fun(optimize.rosen, bounds, **kwargs).x
  380. def differential_evolution(self, **kwargs):
  381. return self.opt(optimize.differential_evolution, **kwargs)
  382. def dual_annealing(self, **kwargs):
  383. return self.opt(optimize.dual_annealing, **kwargs)
  384. def check_grad(self, **kwargs):
  385. rng = np.random.default_rng(3458934594269824562)
  386. x = rng.random(3)
  387. return optimize.check_grad(optimize.rosen, optimize.rosen_der, x,
  388. direction='random', **kwargs)
  389. def random_array(self, **kwargs):
  390. return sparse.random_array((10, 10), density=1.0, **kwargs).toarray()
  391. def random(self, **kwargs):
  392. return sparse.random(10, 10, density=1.0, **kwargs).toarray()
  393. def rand(self, **kwargs):
  394. return sparse.rand(10, 10, density=1.0, **kwargs).toarray()
  395. def svds(self, **kwargs):
  396. rng = np.random.default_rng(3458934594269824562)
  397. A = rng.random((10, 10))
  398. return sparse.linalg.svds(A, **kwargs)
  399. def random_rotation(self, **kwargs):
  400. return spatial.transform.Rotation.random(3, **kwargs).as_matrix()
  401. def goodness_of_fit(self, **kwargs):
  402. rng = np.random.default_rng(3458934594269824562)
  403. data = rng.random(100)
  404. return stats.goodness_of_fit(stats.laplace, data, **kwargs).pvalue
  405. def permutation_test(self, **kwargs):
  406. rng = np.random.default_rng(3458934594269824562)
  407. data = tuple(rng.random((2, 100)))
  408. def statistic(x, y, axis): return np.mean(x, axis=axis) - np.mean(y, axis=axis)
  409. return stats.permutation_test(data, statistic, **kwargs).pvalue
  410. def bootstrap(self, **kwargs):
  411. rng = np.random.default_rng(3458934594269824562)
  412. data = (rng.random(100),)
  413. return stats.bootstrap(data, np.mean, **kwargs).confidence_interval
  414. def dunnett(self, **kwargs):
  415. rng = np.random.default_rng(3458934594269824562)
  416. x, y, control = rng.random((3, 100))
  417. return stats.dunnett(x, y, control=control, **kwargs).pvalue
  418. def sobol_indices(self, **kwargs):
  419. def f_ishigami(x): return (np.sin(x[0]) + 7 * np.sin(x[1]) ** 2
  420. + 0.1 * (x[2] ** 4) * np.sin(x[0]))
  421. dists = [stats.uniform(loc=-np.pi, scale=2 * np.pi),
  422. stats.uniform(loc=-np.pi, scale=2 * np.pi),
  423. stats.uniform(loc=-np.pi, scale=2 * np.pi)]
  424. res = stats.sobol_indices(func=f_ishigami, n=1024, dists=dists, **kwargs)
  425. return res.first_order
  426. def qmc_engine(self, engine, **kwargs):
  427. qrng = engine(d=1, **kwargs)
  428. return qrng.random(4)
  429. def halton(self, **kwargs):
  430. return self.qmc_engine(stats.qmc.Halton, **kwargs)
  431. def sobol(self, **kwargs):
  432. return self.qmc_engine(stats.qmc.Sobol, **kwargs)
  433. def latin_hypercube(self, **kwargs):
  434. return self.qmc_engine(stats.qmc.LatinHypercube, **kwargs)
  435. def poisson_disk(self, **kwargs):
  436. return self.qmc_engine(stats.qmc.PoissonDisk, **kwargs)
  437. def multivariate_normal_qmc(self, **kwargs):
  438. X = stats.qmc.MultivariateNormalQMC([0], **kwargs)
  439. return X.random(4)
  440. def multinomial_qmc(self, **kwargs):
  441. X = stats.qmc.MultinomialQMC([0.5, 0.5], 4, **kwargs)
  442. return X.random(4)
  443. def permutation_method(self, **kwargs):
  444. rng = np.random.default_rng(3458934594269824562)
  445. data = tuple(rng.random((2, 100)))
  446. method = stats.PermutationMethod(**kwargs)
  447. return stats.pearsonr(*data, method=method).pvalue
  448. def bootstrap_method(self, **kwargs):
  449. rng = np.random.default_rng(3458934594269824562)
  450. data = tuple(rng.random((2, 100)))
  451. res = stats.pearsonr(*data)
  452. method = stats.BootstrapMethod(**kwargs)
  453. return res.confidence_interval(method=method)
  454. @pytest.mark.fail_slow(10)
  455. @pytest.mark.slow
  456. @pytest.mark.parametrize("method, arg_name", [
  457. (kmeans, "seed"),
  458. (kmeans2, "seed"),
  459. (barycentric, "random_state"),
  460. (clarkson_woodruff_transform, "seed"),
  461. (basinhopping, "seed"),
  462. (differential_evolution, "seed"),
  463. (dual_annealing, "seed"),
  464. (check_grad, "seed"),
  465. (random_array, 'random_state'),
  466. (random, 'random_state'),
  467. (rand, 'random_state'),
  468. (random_rotation, "random_state"),
  469. (goodness_of_fit, "random_state"),
  470. (permutation_test, "random_state"),
  471. (bootstrap, "random_state"),
  472. (permutation_method, "random_state"),
  473. (bootstrap_method, "random_state"),
  474. (dunnett, "random_state"),
  475. (sobol_indices, "random_state"),
  476. (halton, "seed"),
  477. (sobol, "seed"),
  478. (latin_hypercube, "seed"),
  479. (poisson_disk, "seed"),
  480. (multivariate_normal_qmc, "seed"),
  481. (multinomial_qmc, "seed"),
  482. ])
  483. def test_rng_deterministic(self, method, arg_name):
  484. np.random.seed(None)
  485. seed = 2949672964
  486. rng = np.random.default_rng(seed)
  487. message = "got multiple values for argument now known as `rng`"
  488. with pytest.raises(TypeError, match=message):
  489. method(self, **{'rng': rng, arg_name: seed})
  490. rng = np.random.default_rng(seed)
  491. res1 = method(self, rng=rng)
  492. res2 = method(self, rng=seed)
  493. assert_equal(res2, res1)
  494. if method.__name__ in {"dunnett", "sobol_indices"}:
  495. # the two kwargs have essentially the same behavior for these functions
  496. res3 = method(self, **{arg_name: seed})
  497. assert_equal(res3, res1)
  498. return
  499. rng = np.random.RandomState(seed)
  500. res1 = method(self, **{arg_name: rng})
  501. res2 = method(self, **{arg_name: seed})
  502. if method.__name__ in {"halton", "sobol", "latin_hypercube", "poisson_disk",
  503. "multivariate_normal_qmc", "multinomial_qmc"}:
  504. # For these, passing `random_state=RandomState(seed)` is not the same as
  505. # passing integer `seed`.
  506. res1b = method(self, **{arg_name: np.random.RandomState(seed)})
  507. assert_equal(res1b, res1)
  508. res2b = method(self, **{arg_name: seed})
  509. assert_equal(res2b, res2)
  510. return
  511. np.random.seed(seed)
  512. res3 = method(self, **{arg_name: None})
  513. assert_equal(res2, res1)
  514. assert_equal(res3, res1)