# test_smoke.py (29 KB)
import pickle
from dataclasses import dataclass
from functools import partial

import numpy as np
import pytest
from numpy.random import MT19937, PCG64, PCG64DXSM, SFC64, Generator, Philox
from numpy.testing import assert_, assert_array_equal, assert_equal
# dtypes used to parametrize the ``integers`` tests in this file: the boolean
# dtype followed by the signed and unsigned fixed-width integer dtypes.
DTYPES_BOOL_INT_UINT = (np.bool, np.int8, np.int16, np.int32, np.int64,
                        np.uint8, np.uint16, np.uint32, np.uint64)
  10. def params_0(f):
  11. val = f()
  12. assert_(np.isscalar(val))
  13. val = f(10)
  14. assert_(val.shape == (10,))
  15. val = f((10, 10))
  16. assert_(val.shape == (10, 10))
  17. val = f((10, 10, 10))
  18. assert_(val.shape == (10, 10, 10))
  19. val = f(size=(5, 5))
  20. assert_(val.shape == (5, 5))
  21. def params_1(f, bounded=False):
  22. a = 5.0
  23. b = np.arange(2.0, 12.0)
  24. c = np.arange(2.0, 102.0).reshape((10, 10))
  25. d = np.arange(2.0, 1002.0).reshape((10, 10, 10))
  26. e = np.array([2.0, 3.0])
  27. g = np.arange(2.0, 12.0).reshape((1, 10, 1))
  28. if bounded:
  29. a = 0.5
  30. b = b / (1.5 * b.max())
  31. c = c / (1.5 * c.max())
  32. d = d / (1.5 * d.max())
  33. e = e / (1.5 * e.max())
  34. g = g / (1.5 * g.max())
  35. # Scalar
  36. f(a)
  37. # Scalar - size
  38. f(a, size=(10, 10))
  39. # 1d
  40. f(b)
  41. # 2d
  42. f(c)
  43. # 3d
  44. f(d)
  45. # 1d size
  46. f(b, size=10)
  47. # 2d - size - broadcast
  48. f(e, size=(10, 2))
  49. # 3d - size
  50. f(g, size=(10, 10, 10))
  51. def comp_state(state1, state2):
  52. identical = True
  53. if isinstance(state1, dict):
  54. for key in state1:
  55. identical &= comp_state(state1[key], state2[key])
  56. elif type(state1) != type(state2):
  57. identical &= type(state1) == type(state2)
  58. elif (isinstance(state1, (list, tuple, np.ndarray)) and isinstance(
  59. state2, (list, tuple, np.ndarray))):
  60. for s1, s2 in zip(state1, state2):
  61. identical &= comp_state(s1, s2)
  62. else:
  63. identical &= state1 == state2
  64. return identical
  65. def warmup(rg, n=None):
  66. if n is None:
  67. n = 11 + np.random.randint(0, 20)
  68. rg.standard_normal(n)
  69. rg.standard_normal(n)
  70. rg.standard_normal(n, dtype=np.float32)
  71. rg.standard_normal(n, dtype=np.float32)
  72. rg.integers(0, 2 ** 24, n, dtype=np.uint64)
  73. rg.integers(0, 2 ** 48, n, dtype=np.uint64)
  74. rg.standard_gamma(11.0, n)
  75. rg.standard_gamma(11.0, n, dtype=np.float32)
  76. rg.random(n, dtype=np.float64)
  77. rg.random(n, dtype=np.float32)
@dataclass
class RNGData:
    """Bundle returned by ``_create_rng`` describing one generator under test."""

    # BitGenerator class (not an instance); tests call it to build fresh
    # bit generators, with or without the seed.
    bit_generator: type[np.random.BitGenerator]
    # Argument handed to ``bit_generator.advance`` in ``test_advance``; the
    # test classes for MT19937 and SFC64 set it to None.
    advance: int | None
    # Positional seed arguments, unpacked as ``bit_generator(*seed)``.
    seed: list[int]
    # Ready-to-use Generator that most tests draw from.
    rg: Generator
    # Selects the seed-array dtype in ``test_seed_array`` (32 -> uint32,
    # anything else -> uint64); None makes that test skip.
    seed_vector_bits: int | None
  85. class RNG:
  86. @classmethod
  87. def _create_rng(cls):
  88. # Overridden in test classes. Place holder to silence IDE noise
  89. bit_generator = PCG64
  90. advance = None
  91. seed = [12345]
  92. rg = Generator(bit_generator(*seed))
  93. seed_vector_bits = 64
  94. return RNGData(bit_generator, advance, seed, rg, seed_vector_bits)
  95. def test_init(self):
  96. data = self._create_rng()
  97. data.rg = Generator(data.bit_generator())
  98. state = data.rg.bit_generator.state
  99. data.rg.standard_normal(1)
  100. data.rg.standard_normal(1)
  101. data.rg.bit_generator.state = state
  102. new_state = data.rg.bit_generator.state
  103. assert_(comp_state(state, new_state))
  104. def test_advance(self):
  105. data = self._create_rng()
  106. state = data.rg.bit_generator.state
  107. if hasattr(data.rg.bit_generator, 'advance'):
  108. data.rg.bit_generator.advance(data.advance)
  109. assert_(not comp_state(state, data.rg.bit_generator.state))
  110. else:
  111. bitgen_name = data.rg.bit_generator.__class__.__name__
  112. pytest.skip(f'Advance is not supported by {bitgen_name}')
  113. def test_jump(self):
  114. rg = self._create_rng().rg
  115. state = rg.bit_generator.state
  116. if hasattr(rg.bit_generator, 'jumped'):
  117. bit_gen2 = rg.bit_generator.jumped()
  118. jumped_state = bit_gen2.state
  119. assert_(not comp_state(state, jumped_state))
  120. rg.random(2 * 3 * 5 * 7 * 11 * 13 * 17)
  121. rg.bit_generator.state = state
  122. bit_gen3 = rg.bit_generator.jumped()
  123. rejumped_state = bit_gen3.state
  124. assert_(comp_state(jumped_state, rejumped_state))
  125. else:
  126. bitgen_name = rg.bit_generator.__class__.__name__
  127. if bitgen_name not in ('SFC64',):
  128. raise AttributeError(f'no "jumped" in {bitgen_name}')
  129. pytest.skip(f'Jump is not supported by {bitgen_name}')
  130. def test_uniform(self):
  131. rg = self._create_rng().rg
  132. r = rg.uniform(-1.0, 0.0, size=10)
  133. assert_(len(r) == 10)
  134. assert_((r > -1).all())
  135. assert_((r <= 0).all())
  136. def test_uniform_array(self):
  137. rg = self._create_rng().rg
  138. r = rg.uniform(np.array([-1.0] * 10), 0.0, size=10)
  139. assert_(len(r) == 10)
  140. assert_((r > -1).all())
  141. assert_((r <= 0).all())
  142. r = rg.uniform(np.array([-1.0] * 10),
  143. np.array([0.0] * 10), size=10)
  144. assert_(len(r) == 10)
  145. assert_((r > -1).all())
  146. assert_((r <= 0).all())
  147. r = rg.uniform(-1.0, np.array([0.0] * 10), size=10)
  148. assert_(len(r) == 10)
  149. assert_((r > -1).all())
  150. assert_((r <= 0).all())
  151. def test_random(self):
  152. rg = self._create_rng().rg
  153. assert_(len(rg.random(10)) == 10)
  154. params_0(rg.random)
  155. def test_standard_normal_zig(self):
  156. rg = self._create_rng().rg
  157. assert_(len(rg.standard_normal(10)) == 10)
  158. def test_standard_normal(self):
  159. rg = self._create_rng().rg
  160. assert_(len(rg.standard_normal(10)) == 10)
  161. params_0(rg.standard_normal)
  162. def test_standard_gamma(self):
  163. rg = self._create_rng().rg
  164. assert_(len(rg.standard_gamma(10, 10)) == 10)
  165. assert_(len(rg.standard_gamma(np.array([10] * 10), 10)) == 10)
  166. params_1(rg.standard_gamma)
  167. def test_standard_exponential(self):
  168. rg = self._create_rng().rg
  169. assert_(len(rg.standard_exponential(10)) == 10)
  170. params_0(rg.standard_exponential)
  171. def test_standard_exponential_float(self):
  172. rg = self._create_rng().rg
  173. randoms = rg.standard_exponential(10, dtype='float32')
  174. assert_(len(randoms) == 10)
  175. assert randoms.dtype == np.float32
  176. params_0(partial(rg.standard_exponential, dtype='float32'))
  177. def test_standard_exponential_float_log(self):
  178. rg = self._create_rng().rg
  179. randoms = rg.standard_exponential(10, dtype='float32',
  180. method='inv')
  181. assert_(len(randoms) == 10)
  182. assert randoms.dtype == np.float32
  183. params_0(partial(rg.standard_exponential, dtype='float32',
  184. method='inv'))
  185. def test_standard_cauchy(self):
  186. rg = self._create_rng().rg
  187. assert_(len(rg.standard_cauchy(10)) == 10)
  188. params_0(rg.standard_cauchy)
  189. def test_standard_t(self):
  190. rg = self._create_rng().rg
  191. assert_(len(rg.standard_t(10, 10)) == 10)
  192. params_1(rg.standard_t)
  193. def test_binomial(self):
  194. rg = self._create_rng().rg
  195. assert_(rg.binomial(10, .5) >= 0)
  196. assert_(rg.binomial(1000, .5) >= 0)
  197. def test_reset_state(self):
  198. rg = self._create_rng().rg
  199. state = rg.bit_generator.state
  200. int_1 = rg.integers(2**31)
  201. rg.bit_generator.state = state
  202. int_2 = rg.integers(2**31)
  203. assert_(int_1 == int_2)
  204. def test_entropy_init(self):
  205. bit_generator = self._create_rng().bit_generator
  206. rg = Generator(bit_generator())
  207. rg2 = Generator(bit_generator())
  208. assert_(not comp_state(rg.bit_generator.state,
  209. rg2.bit_generator.state))
  210. def test_seed(self):
  211. data = self._create_rng()
  212. rg = Generator(data.bit_generator(*data.seed))
  213. rg2 = Generator(data.bit_generator(*data.seed))
  214. rg.random()
  215. rg2.random()
  216. assert_(comp_state(rg.bit_generator.state, rg2.bit_generator.state))
  217. def test_reset_state_gauss(self):
  218. data = self._create_rng()
  219. rg = Generator(data.bit_generator(*data.seed))
  220. rg.standard_normal()
  221. state = rg.bit_generator.state
  222. n1 = rg.standard_normal(size=10)
  223. rg2 = Generator(data.bit_generator())
  224. rg2.bit_generator.state = state
  225. n2 = rg2.standard_normal(size=10)
  226. assert_array_equal(n1, n2)
  227. def test_reset_state_uint32(self):
  228. data = self._create_rng()
  229. rg = Generator(data.bit_generator(*data.seed))
  230. rg.integers(0, 2 ** 24, 120, dtype=np.uint32)
  231. state = rg.bit_generator.state
  232. n1 = rg.integers(0, 2 ** 24, 10, dtype=np.uint32)
  233. rg2 = Generator(data.bit_generator())
  234. rg2.bit_generator.state = state
  235. n2 = rg2.integers(0, 2 ** 24, 10, dtype=np.uint32)
  236. assert_array_equal(n1, n2)
  237. def test_reset_state_float(self):
  238. data = self._create_rng()
  239. rg = Generator(data.bit_generator(*data.seed))
  240. rg.random(dtype='float32')
  241. state = rg.bit_generator.state
  242. n1 = rg.random(size=10, dtype='float32')
  243. rg2 = Generator(data.bit_generator())
  244. rg2.bit_generator.state = state
  245. n2 = rg2.random(size=10, dtype='float32')
  246. assert_((n1 == n2).all())
  247. def test_shuffle(self):
  248. rg = self._create_rng().rg
  249. original = np.arange(200, 0, -1)
  250. permuted = rg.permutation(original)
  251. assert_((original != permuted).any())
  252. def test_permutation(self):
  253. rg = self._create_rng().rg
  254. original = np.arange(200, 0, -1)
  255. permuted = rg.permutation(original)
  256. assert_((original != permuted).any())
  257. def test_beta(self):
  258. rg = self._create_rng().rg
  259. vals = rg.beta(2.0, 2.0, 10)
  260. assert_(len(vals) == 10)
  261. vals = rg.beta(np.array([2.0] * 10), 2.0)
  262. assert_(len(vals) == 10)
  263. vals = rg.beta(2.0, np.array([2.0] * 10))
  264. assert_(len(vals) == 10)
  265. vals = rg.beta(np.array([2.0] * 10), np.array([2.0] * 10))
  266. assert_(len(vals) == 10)
  267. vals = rg.beta(np.array([2.0] * 10), np.array([[2.0]] * 10))
  268. assert_(vals.shape == (10, 10))
  269. def test_bytes(self):
  270. rg = self._create_rng().rg
  271. vals = rg.bytes(10)
  272. assert_(len(vals) == 10)
  273. def test_chisquare(self):
  274. rg = self._create_rng().rg
  275. vals = rg.chisquare(2.0, 10)
  276. assert_(len(vals) == 10)
  277. params_1(rg.chisquare)
  278. def test_exponential(self):
  279. rg = self._create_rng().rg
  280. vals = rg.exponential(2.0, 10)
  281. assert_(len(vals) == 10)
  282. params_1(rg.exponential)
  283. def test_f(self):
  284. rg = self._create_rng().rg
  285. vals = rg.f(3, 1000, 10)
  286. assert_(len(vals) == 10)
  287. def test_gamma(self):
  288. rg = self._create_rng().rg
  289. vals = rg.gamma(3, 2, 10)
  290. assert_(len(vals) == 10)
  291. def test_geometric(self):
  292. rg = self._create_rng().rg
  293. vals = rg.geometric(0.5, 10)
  294. assert_(len(vals) == 10)
  295. params_1(rg.exponential, bounded=True)
  296. def test_gumbel(self):
  297. rg = self._create_rng().rg
  298. vals = rg.gumbel(2.0, 2.0, 10)
  299. assert_(len(vals) == 10)
  300. def test_laplace(self):
  301. rg = self._create_rng().rg
  302. vals = rg.laplace(2.0, 2.0, 10)
  303. assert_(len(vals) == 10)
  304. def test_logitic(self):
  305. rg = self._create_rng().rg
  306. vals = rg.logistic(2.0, 2.0, 10)
  307. assert_(len(vals) == 10)
  308. def test_logseries(self):
  309. rg = self._create_rng().rg
  310. vals = rg.logseries(0.5, 10)
  311. assert_(len(vals) == 10)
  312. def test_negative_binomial(self):
  313. rg = self._create_rng().rg
  314. vals = rg.negative_binomial(10, 0.2, 10)
  315. assert_(len(vals) == 10)
  316. def test_noncentral_chisquare(self):
  317. rg = self._create_rng().rg
  318. vals = rg.noncentral_chisquare(10, 2, 10)
  319. assert_(len(vals) == 10)
  320. def test_noncentral_f(self):
  321. rg = self._create_rng().rg
  322. vals = rg.noncentral_f(3, 1000, 2, 10)
  323. assert_(len(vals) == 10)
  324. vals = rg.noncentral_f(np.array([3] * 10), 1000, 2)
  325. assert_(len(vals) == 10)
  326. vals = rg.noncentral_f(3, np.array([1000] * 10), 2)
  327. assert_(len(vals) == 10)
  328. vals = rg.noncentral_f(3, 1000, np.array([2] * 10))
  329. assert_(len(vals) == 10)
  330. def test_normal(self):
  331. rg = self._create_rng().rg
  332. vals = rg.normal(10, 0.2, 10)
  333. assert_(len(vals) == 10)
  334. def test_pareto(self):
  335. rg = self._create_rng().rg
  336. vals = rg.pareto(3.0, 10)
  337. assert_(len(vals) == 10)
  338. def test_poisson(self):
  339. rg = self._create_rng().rg
  340. vals = rg.poisson(10, 10)
  341. assert_(len(vals) == 10)
  342. vals = rg.poisson(np.array([10] * 10))
  343. assert_(len(vals) == 10)
  344. params_1(rg.poisson)
  345. def test_power(self):
  346. rg = self._create_rng().rg
  347. vals = rg.power(0.2, 10)
  348. assert_(len(vals) == 10)
  349. def test_integers(self):
  350. rg = self._create_rng().rg
  351. vals = rg.integers(10, 20, 10)
  352. assert_(len(vals) == 10)
  353. def test_rayleigh(self):
  354. rg = self._create_rng().rg
  355. vals = rg.rayleigh(0.2, 10)
  356. assert_(len(vals) == 10)
  357. params_1(rg.rayleigh, bounded=True)
  358. def test_vonmises(self):
  359. rg = self._create_rng().rg
  360. vals = rg.vonmises(10, 0.2, 10)
  361. assert_(len(vals) == 10)
  362. def test_wald(self):
  363. rg = self._create_rng().rg
  364. vals = rg.wald(1.0, 1.0, 10)
  365. assert_(len(vals) == 10)
  366. def test_weibull(self):
  367. rg = self._create_rng().rg
  368. vals = rg.weibull(1.0, 10)
  369. assert_(len(vals) == 10)
  370. def test_zipf(self):
  371. rg = self._create_rng().rg
  372. vec_1d = np.arange(2.0, 102.0)
  373. vec_2d = np.arange(2.0, 102.0)[None, :]
  374. mat = np.arange(2.0, 102.0, 0.01).reshape((100, 100))
  375. vals = rg.zipf(10, 10)
  376. assert_(len(vals) == 10)
  377. vals = rg.zipf(vec_1d)
  378. assert_(len(vals) == 100)
  379. vals = rg.zipf(vec_2d)
  380. assert_(vals.shape == (1, 100))
  381. vals = rg.zipf(mat)
  382. assert_(vals.shape == (100, 100))
  383. def test_hypergeometric(self):
  384. rg = self._create_rng().rg
  385. vals = rg.hypergeometric(25, 25, 20)
  386. assert_(np.isscalar(vals))
  387. vals = rg.hypergeometric(np.array([25] * 10), 25, 20)
  388. assert_(vals.shape == (10,))
  389. def test_triangular(self):
  390. rg = self._create_rng().rg
  391. vals = rg.triangular(-5, 0, 5)
  392. assert_(np.isscalar(vals))
  393. vals = rg.triangular(-5, np.array([0] * 10), 5)
  394. assert_(vals.shape == (10,))
  395. def test_multivariate_normal(self):
  396. rg = self._create_rng().rg
  397. mean = [0, 0]
  398. cov = [[1, 0], [0, 100]] # diagonal covariance
  399. x = rg.multivariate_normal(mean, cov, 5000)
  400. assert_(x.shape == (5000, 2))
  401. x_zig = rg.multivariate_normal(mean, cov, 5000)
  402. assert_(x.shape == (5000, 2))
  403. x_inv = rg.multivariate_normal(mean, cov, 5000)
  404. assert_(x.shape == (5000, 2))
  405. assert_((x_zig != x_inv).any())
  406. def test_multinomial(self):
  407. rg = self._create_rng().rg
  408. vals = rg.multinomial(100, [1.0 / 3, 2.0 / 3])
  409. assert_(vals.shape == (2,))
  410. vals = rg.multinomial(100, [1.0 / 3, 2.0 / 3], size=10)
  411. assert_(vals.shape == (10, 2))
  412. def test_dirichlet(self):
  413. rg = self._create_rng().rg
  414. s = rg.dirichlet((10, 5, 3), 20)
  415. assert_(s.shape == (20, 3))
  416. def test_pickle(self):
  417. rg = self._create_rng().rg
  418. pick = pickle.dumps(rg)
  419. unpick = pickle.loads(pick)
  420. assert_(type(rg) == type(unpick))
  421. assert_(comp_state(rg.bit_generator.state,
  422. unpick.bit_generator.state))
  423. pick = pickle.dumps(rg)
  424. unpick = pickle.loads(pick)
  425. assert_(type(rg) == type(unpick))
  426. assert_(comp_state(rg.bit_generator.state,
  427. unpick.bit_generator.state))
  428. def test_seed_array(self):
  429. data = self._create_rng()
  430. if data.seed_vector_bits is None:
  431. bitgen_name = data.bit_generator.__name__
  432. pytest.skip(f'Vector seeding is not supported by {bitgen_name}')
  433. if data.seed_vector_bits == 32:
  434. dtype = np.uint32
  435. else:
  436. dtype = np.uint64
  437. seed = np.array([1], dtype=dtype)
  438. bg = data.bit_generator(seed)
  439. state1 = bg.state
  440. bg = data.bit_generator(1)
  441. state2 = bg.state
  442. assert_(comp_state(state1, state2))
  443. seed = np.arange(4, dtype=dtype)
  444. bg = data.bit_generator(seed)
  445. state1 = bg.state
  446. bg = data.bit_generator(seed[0])
  447. state2 = bg.state
  448. assert_(not comp_state(state1, state2))
  449. seed = np.arange(1500, dtype=dtype)
  450. bg = data.bit_generator(seed)
  451. state1 = bg.state
  452. bg = data.bit_generator(seed[0])
  453. state2 = bg.state
  454. assert_(not comp_state(state1, state2))
  455. seed = 2 ** np.mod(np.arange(1500, dtype=dtype),
  456. data.seed_vector_bits - 1) + 1
  457. bg = data.bit_generator(seed)
  458. state1 = bg.state
  459. bg = data.bit_generator(seed[0])
  460. state2 = bg.state
  461. assert_(not comp_state(state1, state2))
  462. def test_uniform_float(self):
  463. bit_generator = self._create_rng().bit_generator
  464. rg = Generator(bit_generator(12345))
  465. warmup(rg)
  466. state = rg.bit_generator.state
  467. r1 = rg.random(11, dtype=np.float32)
  468. rg2 = Generator(bit_generator())
  469. warmup(rg2)
  470. rg2.bit_generator.state = state
  471. r2 = rg2.random(11, dtype=np.float32)
  472. assert_array_equal(r1, r2)
  473. assert_equal(r1.dtype, np.float32)
  474. assert_(comp_state(rg.bit_generator.state, rg2.bit_generator.state))
  475. def test_gamma_floats(self):
  476. bit_generator = self._create_rng().bit_generator
  477. rg = Generator(bit_generator())
  478. warmup(rg)
  479. state = rg.bit_generator.state
  480. r1 = rg.standard_gamma(4.0, 11, dtype=np.float32)
  481. rg2 = Generator(bit_generator())
  482. warmup(rg2)
  483. rg2.bit_generator.state = state
  484. r2 = rg2.standard_gamma(4.0, 11, dtype=np.float32)
  485. assert_array_equal(r1, r2)
  486. assert_equal(r1.dtype, np.float32)
  487. assert_(comp_state(rg.bit_generator.state, rg2.bit_generator.state))
  488. def test_normal_floats(self):
  489. bit_generator = self._create_rng().bit_generator
  490. rg = Generator(bit_generator())
  491. warmup(rg)
  492. state = rg.bit_generator.state
  493. r1 = rg.standard_normal(11, dtype=np.float32)
  494. rg2 = Generator(bit_generator())
  495. warmup(rg2)
  496. rg2.bit_generator.state = state
  497. r2 = rg2.standard_normal(11, dtype=np.float32)
  498. assert_array_equal(r1, r2)
  499. assert_equal(r1.dtype, np.float32)
  500. assert_(comp_state(rg.bit_generator.state, rg2.bit_generator.state))
  501. def test_normal_zig_floats(self):
  502. bit_generator = self._create_rng().bit_generator
  503. rg = Generator(bit_generator())
  504. warmup(rg)
  505. state = rg.bit_generator.state
  506. r1 = rg.standard_normal(11, dtype=np.float32)
  507. rg2 = Generator(bit_generator())
  508. warmup(rg2)
  509. rg2.bit_generator.state = state
  510. r2 = rg2.standard_normal(11, dtype=np.float32)
  511. assert_array_equal(r1, r2)
  512. assert_equal(r1.dtype, np.float32)
  513. assert_(comp_state(rg.bit_generator.state, rg2.bit_generator.state))
  514. def test_output_fill(self):
  515. rg = self._create_rng().rg
  516. state = rg.bit_generator.state
  517. size = (31, 7, 97)
  518. existing = np.empty(size)
  519. rg.bit_generator.state = state
  520. rg.standard_normal(out=existing)
  521. rg.bit_generator.state = state
  522. direct = rg.standard_normal(size=size)
  523. assert_equal(direct, existing)
  524. sized = np.empty(size)
  525. rg.bit_generator.state = state
  526. rg.standard_normal(out=sized, size=sized.shape)
  527. existing = np.empty(size, dtype=np.float32)
  528. rg.bit_generator.state = state
  529. rg.standard_normal(out=existing, dtype=np.float32)
  530. rg.bit_generator.state = state
  531. direct = rg.standard_normal(size=size, dtype=np.float32)
  532. assert_equal(direct, existing)
  533. def test_output_filling_uniform(self):
  534. rg = self._create_rng().rg
  535. state = rg.bit_generator.state
  536. size = (31, 7, 97)
  537. existing = np.empty(size)
  538. rg.bit_generator.state = state
  539. rg.random(out=existing)
  540. rg.bit_generator.state = state
  541. direct = rg.random(size=size)
  542. assert_equal(direct, existing)
  543. existing = np.empty(size, dtype=np.float32)
  544. rg.bit_generator.state = state
  545. rg.random(out=existing, dtype=np.float32)
  546. rg.bit_generator.state = state
  547. direct = rg.random(size=size, dtype=np.float32)
  548. assert_equal(direct, existing)
  549. def test_output_filling_exponential(self):
  550. rg = self._create_rng().rg
  551. state = rg.bit_generator.state
  552. size = (31, 7, 97)
  553. existing = np.empty(size)
  554. rg.bit_generator.state = state
  555. rg.standard_exponential(out=existing)
  556. rg.bit_generator.state = state
  557. direct = rg.standard_exponential(size=size)
  558. assert_equal(direct, existing)
  559. existing = np.empty(size, dtype=np.float32)
  560. rg.bit_generator.state = state
  561. rg.standard_exponential(out=existing, dtype=np.float32)
  562. rg.bit_generator.state = state
  563. direct = rg.standard_exponential(size=size, dtype=np.float32)
  564. assert_equal(direct, existing)
  565. def test_output_filling_gamma(self):
  566. rg = self._create_rng().rg
  567. state = rg.bit_generator.state
  568. size = (31, 7, 97)
  569. existing = np.zeros(size)
  570. rg.bit_generator.state = state
  571. rg.standard_gamma(1.0, out=existing)
  572. rg.bit_generator.state = state
  573. direct = rg.standard_gamma(1.0, size=size)
  574. assert_equal(direct, existing)
  575. existing = np.zeros(size, dtype=np.float32)
  576. rg.bit_generator.state = state
  577. rg.standard_gamma(1.0, out=existing, dtype=np.float32)
  578. rg.bit_generator.state = state
  579. direct = rg.standard_gamma(1.0, size=size, dtype=np.float32)
  580. assert_equal(direct, existing)
  581. def test_output_filling_gamma_broadcast(self):
  582. rg = self._create_rng().rg
  583. state = rg.bit_generator.state
  584. size = (31, 7, 97)
  585. mu = np.arange(97.0) + 1.0
  586. existing = np.zeros(size)
  587. rg.bit_generator.state = state
  588. rg.standard_gamma(mu, out=existing)
  589. rg.bit_generator.state = state
  590. direct = rg.standard_gamma(mu, size=size)
  591. assert_equal(direct, existing)
  592. existing = np.zeros(size, dtype=np.float32)
  593. rg.bit_generator.state = state
  594. rg.standard_gamma(mu, out=existing, dtype=np.float32)
  595. rg.bit_generator.state = state
  596. direct = rg.standard_gamma(mu, size=size, dtype=np.float32)
  597. assert_equal(direct, existing)
  598. def test_output_fill_error(self):
  599. rg = self._create_rng().rg
  600. size = (31, 7, 97)
  601. existing = np.empty(size)
  602. with pytest.raises(TypeError):
  603. rg.standard_normal(out=existing, dtype=np.float32)
  604. with pytest.raises(ValueError):
  605. rg.standard_normal(out=existing[::3])
  606. existing = np.empty(size, dtype=np.float32)
  607. with pytest.raises(TypeError):
  608. rg.standard_normal(out=existing, dtype=np.float64)
  609. existing = np.zeros(size, dtype=np.float32)
  610. with pytest.raises(TypeError):
  611. rg.standard_gamma(1.0, out=existing, dtype=np.float64)
  612. with pytest.raises(ValueError):
  613. rg.standard_gamma(1.0, out=existing[::3], dtype=np.float32)
  614. existing = np.zeros(size, dtype=np.float64)
  615. with pytest.raises(TypeError):
  616. rg.standard_gamma(1.0, out=existing, dtype=np.float32)
  617. with pytest.raises(ValueError):
  618. rg.standard_gamma(1.0, out=existing[::3])
  619. @pytest.mark.parametrize("dtype", DTYPES_BOOL_INT_UINT)
  620. def test_integers_broadcast(self, dtype):
  621. rg = self._create_rng().rg
  622. initial_state = rg.bit_generator.state
  623. def reset_state(rng):
  624. rng.bit_generator.state = initial_state
  625. if dtype == np.bool:
  626. upper = 2
  627. lower = 0
  628. else:
  629. info = np.iinfo(dtype)
  630. upper = int(info.max) + 1
  631. lower = info.min
  632. reset_state(rg)
  633. rg.bit_generator.state = initial_state
  634. a = rg.integers(lower, [upper] * 10, dtype=dtype)
  635. reset_state(rg)
  636. b = rg.integers([lower] * 10, upper, dtype=dtype)
  637. assert_equal(a, b)
  638. reset_state(rg)
  639. c = rg.integers(lower, upper, size=10, dtype=dtype)
  640. assert_equal(a, c)
  641. reset_state(rg)
  642. d = rg.integers(np.array(
  643. [lower] * 10), np.array([upper], dtype=object), size=10,
  644. dtype=dtype)
  645. assert_equal(a, d)
  646. reset_state(rg)
  647. e = rg.integers(
  648. np.array([lower] * 10), np.array([upper] * 10), size=10,
  649. dtype=dtype)
  650. assert_equal(a, e)
  651. reset_state(rg)
  652. a = rg.integers(0, upper, size=10, dtype=dtype)
  653. reset_state(rg)
  654. b = rg.integers([upper] * 10, dtype=dtype)
  655. assert_equal(a, b)
  656. @pytest.mark.parametrize("dtype", DTYPES_BOOL_INT_UINT)
  657. def test_integers_numpy(self, dtype):
  658. rg = self._create_rng().rg
  659. high = np.array([1])
  660. low = np.array([0])
  661. out = rg.integers(low, high, dtype=dtype)
  662. assert out.shape == (1,)
  663. out = rg.integers(low[0], high, dtype=dtype)
  664. assert out.shape == (1,)
  665. out = rg.integers(low, high[0], dtype=dtype)
  666. assert out.shape == (1,)
  667. @pytest.mark.parametrize("dtype", DTYPES_BOOL_INT_UINT)
  668. def test_integers_broadcast_errors(self, dtype):
  669. rg = self._create_rng().rg
  670. if dtype == np.bool:
  671. upper = 2
  672. lower = 0
  673. else:
  674. info = np.iinfo(dtype)
  675. upper = int(info.max) + 1
  676. lower = info.min
  677. with pytest.raises(ValueError):
  678. rg.integers(lower, [upper + 1] * 10, dtype=dtype)
  679. with pytest.raises(ValueError):
  680. rg.integers(lower - 1, [upper] * 10, dtype=dtype)
  681. with pytest.raises(ValueError):
  682. rg.integers([lower - 1], [upper] * 10, dtype=dtype)
  683. with pytest.raises(ValueError):
  684. rg.integers([0], [0], dtype=dtype)
  685. class TestMT19937(RNG):
  686. @classmethod
  687. def _create_rng(cls):
  688. bit_generator = MT19937
  689. advance = None
  690. seed = [2 ** 21 + 2 ** 16 + 2 ** 5 + 1]
  691. rg = Generator(bit_generator(*seed))
  692. seed_vector_bits = 32
  693. return RNGData(bit_generator, advance, seed, rg, seed_vector_bits)
  694. def test_numpy_state(self):
  695. rg = self._create_rng().rg
  696. nprg = np.random.RandomState()
  697. nprg.standard_normal(99)
  698. state = nprg.get_state()
  699. rg.bit_generator.state = state
  700. state2 = rg.bit_generator.state
  701. assert_((state[1] == state2['state']['key']).all())
  702. assert_(state[2] == state2['state']['pos'])
  703. class TestPhilox(RNG):
  704. @classmethod
  705. def _create_rng(cls):
  706. bit_generator = Philox
  707. advance = 2**63 + 2**31 + 2**15 + 1
  708. seed = [12345]
  709. rg = Generator(bit_generator(*seed))
  710. seed_vector_bits = 64
  711. return RNGData(bit_generator, advance, seed, rg, seed_vector_bits)
  712. class TestSFC64(RNG):
  713. @classmethod
  714. def _create_rng(cls):
  715. bit_generator = SFC64
  716. advance = None
  717. seed = [12345]
  718. rg = Generator(bit_generator(*seed))
  719. seed_vector_bits = 192
  720. return RNGData(bit_generator, advance, seed, rg, seed_vector_bits)
  721. class TestPCG64(RNG):
  722. @classmethod
  723. def _create_rng(cls):
  724. bit_generator = PCG64
  725. advance = 2**63 + 2**31 + 2**15 + 1
  726. seed = [12345]
  727. rg = Generator(bit_generator(*seed))
  728. seed_vector_bits = 64
  729. return RNGData(bit_generator, advance, seed, rg, seed_vector_bits)
  730. class TestPCG64DXSM(RNG):
  731. @classmethod
  732. def _create_rng(cls):
  733. bit_generator = PCG64DXSM
  734. advance = 2**63 + 2**31 + 2**15 + 1
  735. seed = [12345]
  736. rg = Generator(bit_generator(*seed))
  737. seed_vector_bits = 64
  738. return RNGData(bit_generator, advance, seed, rg, seed_vector_bits)
  739. class TestDefaultRNG(RNG):
  740. @classmethod
  741. def _create_rng(cls):
  742. # This will duplicate some tests that directly instantiate a fresh
  743. # Generator(), but that's okay.
  744. bit_generator = PCG64
  745. advance = 2**63 + 2**31 + 2**15 + 1
  746. seed = [12345]
  747. rg = np.random.default_rng(*seed)
  748. seed_vector_bits = 64
  749. return RNGData(bit_generator, advance, seed, rg, seed_vector_bits)
  750. def test_default_is_pcg64(self):
  751. # In order to change the default BitGenerator, we'll go through
  752. # a deprecation cycle to move to a different function.
  753. rg = self._create_rng().rg
  754. assert_(isinstance(rg.bit_generator, PCG64))
  755. def test_seed(self):
  756. np.random.default_rng()
  757. np.random.default_rng(None)
  758. np.random.default_rng(12345)
  759. np.random.default_rng(0)
  760. np.random.default_rng(43660444402423911716352051725018508569)
  761. np.random.default_rng([43660444402423911716352051725018508569,
  762. 279705150948142787361475340226491943209])
  763. with pytest.raises(ValueError):
  764. np.random.default_rng(-1)
  765. with pytest.raises(ValueError):
  766. np.random.default_rng([12345, -1])