test_discrete_basic.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625
  1. import warnings
  2. import numpy.testing as npt
  3. from numpy.testing import assert_allclose
  4. import numpy as np
  5. import pytest
  6. from scipy import stats
  7. from scipy.special import _ufuncs
  8. from .common_tests import (check_normalization, check_moment,
  9. check_mean_expect,
  10. check_var_expect, check_skew_expect,
  11. check_kurt_expect, check_entropy,
  12. check_private_entropy, check_edge_support,
  13. check_named_args, check_random_state_property,
  14. check_pickling, check_rvs_broadcast,
  15. check_freezing,)
  16. from scipy.stats._distr_params import distdiscrete, invdistdiscrete
  17. from scipy.stats._distn_infrastructure import rv_discrete_frozen
  18. vals = ([1, 2, 3, 4], [0.1, 0.2, 0.3, 0.4])
  19. distdiscrete += [[stats.rv_discrete(values=vals), ()]]
  20. # For these distributions, test_discrete_basic only runs with test mode full
  21. distslow = {'nhypergeom'}
  22. # Override number of ULPs adjustment for `check_cdf_ppf`
  23. roundtrip_cdf_ppf_exceptions = {'nbinom': 30}
  24. def cases_test_discrete_basic():
  25. seen = set()
  26. for distname, arg in distdiscrete:
  27. if distname in distslow:
  28. yield pytest.param(distname, arg, distname, marks=pytest.mark.slow)
  29. else:
  30. yield distname, arg, distname not in seen
  31. seen.add(distname)
  32. @pytest.mark.parametrize('distname,arg,first_case', cases_test_discrete_basic())
  33. def test_discrete_basic(distname, arg, first_case, num_parallel_threads):
  34. if (isinstance(distname, str) and distname.startswith('nchypergeom')
  35. and num_parallel_threads > 1):
  36. pytest.skip(reason='nchypergeom has a global random generator')
  37. try:
  38. distfn = getattr(stats, distname)
  39. except TypeError:
  40. distfn = distname
  41. distname = 'sample distribution'
  42. rng = np.random.RandomState(9765456)
  43. rvs = distfn.rvs(*arg, size=2000, random_state=rng)
  44. supp = np.unique(rvs)
  45. m, v = distfn.stats(*arg)
  46. check_cdf_ppf(distfn, arg, supp, distname + ' cdf_ppf')
  47. check_pmf_cdf(distfn, arg, distname)
  48. check_oth(distfn, arg, supp, distname + ' oth')
  49. check_edge_support(distfn, arg)
  50. alpha = 0.01
  51. check_discrete_chisquare(distfn, arg, rvs, alpha,
  52. distname + ' chisquare')
  53. if first_case:
  54. locscale_defaults = (0,)
  55. meths = [distfn.pmf, distfn.logpmf, distfn.cdf, distfn.logcdf,
  56. distfn.logsf]
  57. # make sure arguments are within support
  58. # for some distributions, this needs to be overridden
  59. spec_k = {'randint': 11, 'hypergeom': 4, 'bernoulli': 0,
  60. 'nchypergeom_wallenius': 6}
  61. k = spec_k.get(distname, 1)
  62. check_named_args(distfn, k, arg, locscale_defaults, meths)
  63. if distname != 'sample distribution':
  64. check_scale_docstring(distfn)
  65. if num_parallel_threads == 1:
  66. check_random_state_property(distfn, arg)
  67. if distname not in {'poisson_binom'}: # can't be pickled
  68. check_pickling(distfn, arg)
  69. check_freezing(distfn, arg)
  70. # Entropy
  71. check_entropy(distfn, arg, distname)
  72. if distfn.__class__._entropy != stats.rv_discrete._entropy:
  73. check_private_entropy(distfn, arg, stats.rv_discrete)
  74. @pytest.mark.parametrize('distname,arg', distdiscrete)
  75. def test_moments(distname, arg):
  76. try:
  77. distfn = getattr(stats, distname)
  78. except TypeError:
  79. distfn = distname
  80. distname = 'sample distribution'
  81. m, v, s, k = distfn.stats(*arg, moments='mvsk')
  82. check_normalization(distfn, arg, distname)
  83. # compare `stats` and `moment` methods
  84. check_moment(distfn, arg, m, v, distname)
  85. check_mean_expect(distfn, arg, m, distname)
  86. check_var_expect(distfn, arg, m, v, distname)
  87. check_skew_expect(distfn, arg, m, v, s, distname)
  88. with warnings.catch_warnings():
  89. if distname in ['zipf', 'betanbinom']:
  90. warnings.simplefilter("ignore", RuntimeWarning)
  91. check_kurt_expect(distfn, arg, m, v, k, distname)
  92. # frozen distr moments
  93. check_moment_frozen(distfn, arg, m, 1)
  94. check_moment_frozen(distfn, arg, v+m*m, 2)
  95. @pytest.mark.parametrize('dist,shape_args', distdiscrete)
  96. def test_rvs_broadcast(dist, shape_args):
  97. # If shape_only is True, it means the _rvs method of the
  98. # distribution uses more than one random number to generate a random
  99. # variate. That means the result of using rvs with broadcasting or
  100. # with a nontrivial size will not necessarily be the same as using the
  101. # numpy.vectorize'd version of rvs(), so we can only compare the shapes
  102. # of the results, not the values.
  103. # Whether or not a distribution is in the following list is an
  104. # implementation detail of the distribution, not a requirement. If
  105. # the implementation the rvs() method of a distribution changes, this
  106. # test might also have to be changed.
  107. shape_only = dist in ['betabinom', 'betanbinom', 'skellam', 'yulesimon',
  108. 'dlaplace', 'nchypergeom_fisher',
  109. 'nchypergeom_wallenius', 'poisson_binom']
  110. try:
  111. distfunc = getattr(stats, dist)
  112. except TypeError:
  113. distfunc = dist
  114. dist = f'rv_discrete(values=({dist.xk!r}, {dist.pk!r}))'
  115. loc = np.zeros(2)
  116. nargs = distfunc.numargs
  117. allargs = []
  118. bshape = []
  119. if dist == 'poisson_binom':
  120. # normal rules apply except the last axis of `p` is ignored
  121. p = np.full((3, 1, 10), 0.5)
  122. allargs = (p, loc)
  123. bshape = (3, 2)
  124. check_rvs_broadcast(distfunc, dist, allargs,
  125. bshape, shape_only, [np.dtype(int)])
  126. return
  127. # Generate shape parameter arguments...
  128. for k in range(nargs):
  129. shp = (k + 3,) + (1,)*(k + 1)
  130. param_val = shape_args[k]
  131. allargs.append(np.full(shp, param_val))
  132. bshape.insert(0, shp[0])
  133. allargs.append(loc)
  134. bshape.append(loc.size)
  135. # bshape holds the expected shape when loc, scale, and the shape
  136. # parameters are all broadcast together.
  137. check_rvs_broadcast(
  138. distfunc, dist, allargs, bshape, shape_only, [np.dtype(int)]
  139. )
  140. @pytest.mark.parametrize('dist,args', distdiscrete)
  141. def test_ppf_with_loc(dist, args):
  142. try:
  143. distfn = getattr(stats, dist)
  144. except TypeError:
  145. distfn = dist
  146. #check with a negative, no and positive relocation.
  147. rng = np.random.default_rng(5108587887)
  148. re_locs = [rng.integers(-10, -1), 0, rng.integers(1, 10)]
  149. _a, _b = distfn.support(*args)
  150. for loc in re_locs:
  151. npt.assert_array_equal(
  152. [_a-1+loc, _b+loc],
  153. [distfn.ppf(0.0, *args, loc=loc), distfn.ppf(1.0, *args, loc=loc)]
  154. )
  155. @pytest.mark.parametrize('dist, args', distdiscrete)
  156. def test_isf_with_loc(dist, args):
  157. try:
  158. distfn = getattr(stats, dist)
  159. except TypeError:
  160. distfn = dist
  161. # check with a negative, no and positive relocation.
  162. rng = np.random.default_rng(4030503535)
  163. re_locs = [rng.integers(-10, -1), 0, rng.integers(1, 10)]
  164. _a, _b = distfn.support(*args)
  165. for loc in re_locs:
  166. expected = _b + loc, _a - 1 + loc
  167. res = distfn.isf(0., *args, loc=loc), distfn.isf(1., *args, loc=loc)
  168. npt.assert_array_equal(expected, res)
  169. # test broadcasting behaviour
  170. re_locs = [rng.integers(-10, -1, size=(5, 3)),
  171. np.zeros((5, 3)),
  172. rng.integers(1, 10, size=(5, 3))]
  173. _a, _b = distfn.support(*args)
  174. for loc in re_locs:
  175. expected = _b + loc, _a - 1 + loc
  176. res = distfn.isf(0., *args, loc=loc), distfn.isf(1., *args, loc=loc)
  177. npt.assert_array_equal(expected, res)
  178. def check_cdf_ppf(distfn, arg, supp, msg):
  179. # supp is assumed to be an array of integers in the support of distfn
  180. # (but not necessarily all the integers in the support).
  181. # This test assumes that the PMF of any value in the support of the
  182. # distribution is greater than 1e-8.
  183. # cdf is a step function, and ppf(q) = min{k : cdf(k) >= q, k integer}
  184. cdf_supp = distfn.cdf(supp, *arg)
  185. # In very rare cases, the finite precision calculation of ppf(cdf(supp))
  186. # can produce an array in which an element is off by one. We nudge the
  187. # CDF values down by a few ULPs help to avoid this.
  188. n_ulps = roundtrip_cdf_ppf_exceptions.get(distfn.name, 15)
  189. cdf_supp0 = cdf_supp - n_ulps*np.spacing(cdf_supp)
  190. npt.assert_array_equal(distfn.ppf(cdf_supp0, *arg),
  191. supp, msg + '-roundtrip')
  192. # Repeat the same calculation, but with the CDF values decreased by 1e-8.
  193. npt.assert_array_equal(distfn.ppf(distfn.cdf(supp, *arg) - 1e-8, *arg),
  194. supp, msg + '-roundtrip')
  195. if not hasattr(distfn, 'xk'):
  196. _a, _b = distfn.support(*arg)
  197. supp1 = supp[supp < _b]
  198. npt.assert_array_equal(distfn.ppf(distfn.cdf(supp1, *arg) + 1e-8, *arg),
  199. supp1 + distfn.inc, msg + ' ppf-cdf-next')
  200. def check_pmf_cdf(distfn, arg, distname):
  201. if hasattr(distfn, 'xk'):
  202. index = distfn.xk
  203. else:
  204. startind = int(distfn.ppf(0.01, *arg) - 1)
  205. index = list(range(startind, startind + 10))
  206. cdfs = distfn.cdf(index, *arg)
  207. pmfs_cum = distfn.pmf(index, *arg).cumsum()
  208. atol, rtol = 1e-10, 1e-10
  209. if distname == 'skellam': # ncx2 accuracy
  210. atol, rtol = 1e-5, 1e-5
  211. npt.assert_allclose(cdfs - cdfs[0], pmfs_cum - pmfs_cum[0],
  212. atol=atol, rtol=rtol)
  213. # also check that pmf at non-integral k is zero
  214. k = np.asarray(index)
  215. k_shifted = k[:-1] + np.diff(k)/2
  216. npt.assert_equal(distfn.pmf(k_shifted, *arg), 0)
  217. # better check frozen distributions, and also when loc != 0
  218. loc = 0.5
  219. dist = distfn(loc=loc, *arg)
  220. npt.assert_allclose(dist.pmf(k[1:] + loc), np.diff(dist.cdf(k + loc)))
  221. npt.assert_equal(dist.pmf(k_shifted + loc), 0)
  222. def check_moment_frozen(distfn, arg, m, k):
  223. npt.assert_allclose(distfn(*arg).moment(k), m,
  224. atol=1e-10, rtol=1e-10)
  225. def check_oth(distfn, arg, supp, msg):
  226. # checking other methods of distfn
  227. npt.assert_allclose(distfn.sf(supp, *arg), 1. - distfn.cdf(supp, *arg),
  228. atol=1e-10, rtol=1e-10)
  229. q = np.linspace(0.01, 0.99, 20)
  230. npt.assert_allclose(distfn.isf(q, *arg), distfn.ppf(1. - q, *arg),
  231. atol=1e-10, rtol=1e-10)
  232. median_sf = distfn.isf(0.5, *arg)
  233. npt.assert_(distfn.sf(median_sf - 1, *arg) > 0.5)
  234. npt.assert_(distfn.cdf(median_sf + 1, *arg) > 0.5)
  235. def check_discrete_chisquare(distfn, arg, rvs, alpha, msg):
  236. """Perform chisquare test for random sample of a discrete distribution
  237. Parameters
  238. ----------
  239. distname : string
  240. name of distribution function
  241. arg : sequence
  242. parameters of distribution
  243. alpha : float
  244. significance level, threshold for p-value
  245. Returns
  246. -------
  247. result : bool
  248. 0 if test passes, 1 if test fails
  249. """
  250. wsupp = 0.05
  251. # construct intervals with minimum mass `wsupp`.
  252. # intervals are left-half-open as in a cdf difference
  253. _a, _b = distfn.support(*arg)
  254. lo = int(max(_a, -1000))
  255. high = int(min(_b, 1000)) + 1
  256. distsupport = range(lo, high)
  257. last = 0
  258. distsupp = [lo]
  259. distmass = []
  260. for ii in distsupport:
  261. current = distfn.cdf(ii, *arg)
  262. if current - last >= wsupp - 1e-14:
  263. distsupp.append(ii)
  264. distmass.append(current - last)
  265. last = current
  266. if current > (1 - wsupp):
  267. break
  268. if distsupp[-1] < _b:
  269. distsupp.append(_b)
  270. distmass.append(1 - last)
  271. distsupp = np.array(distsupp)
  272. distmass = np.array(distmass)
  273. # convert intervals to right-half-open as required by histogram
  274. histsupp = distsupp + 1e-8
  275. histsupp[0] = _a
  276. # find sample frequencies and perform chisquare test
  277. freq, hsupp = np.histogram(rvs, histsupp)
  278. chis, pval = stats.chisquare(np.array(freq), len(rvs)*distmass)
  279. npt.assert_(
  280. pval > alpha,
  281. f'chisquare - test for {msg} at arg = {str(arg)} with pval = {str(pval)}'
  282. )
  283. def check_scale_docstring(distfn):
  284. if distfn.__doc__ is not None:
  285. # Docstrings can be stripped if interpreter is run with -OO
  286. npt.assert_('scale' not in distfn.__doc__)
  287. @pytest.mark.parametrize('method', ['pmf', 'logpmf', 'cdf', 'logcdf',
  288. 'sf', 'logsf', 'ppf', 'isf'])
  289. @pytest.mark.parametrize('distname, args', distdiscrete)
  290. def test_methods_with_lists(method, distname, args):
  291. # Test that the discrete distributions can accept Python lists
  292. # as arguments.
  293. try:
  294. dist = getattr(stats, distname)
  295. except TypeError:
  296. return
  297. dist_method = getattr(dist, method)
  298. if method in ['ppf', 'isf']:
  299. z = [0.1, 0.2]
  300. else:
  301. z = [0, 1]
  302. p2 = [[p]*2 for p in args]
  303. loc = [0, 1]
  304. result = dist_method(z, *p2, loc=loc)
  305. npt.assert_allclose(result,
  306. [dist_method(*v) for v in zip(z, *p2, loc)],
  307. rtol=1e-15, atol=1e-15)
  308. @pytest.mark.parametrize('distname, args', invdistdiscrete)
  309. def test_cdf_gh13280_regression(distname, args):
  310. # Test for nan output when shape parameters are invalid
  311. dist = getattr(stats, distname)
  312. x = np.arange(-2, 15)
  313. vals = dist.cdf(x, *args)
  314. expected = np.nan
  315. npt.assert_equal(vals, expected)
  316. def cases_test_discrete_integer_shapes():
  317. # distributions parameters that are only allowed to be integral when
  318. # fitting, but are allowed to be real as input to PDF, etc.
  319. integrality_exceptions = {'nbinom': {'n'}, 'betanbinom': {'n'}}
  320. seen = set()
  321. for distname, shapes in distdiscrete:
  322. if distname in seen:
  323. continue
  324. seen.add(distname)
  325. try:
  326. dist = getattr(stats, distname)
  327. except TypeError:
  328. continue
  329. shape_info = dist._shape_info()
  330. for i, shape in enumerate(shape_info):
  331. if (shape.name in integrality_exceptions.get(distname, set()) or
  332. not shape.integrality):
  333. continue
  334. yield distname, shape.name, shapes
  335. @pytest.mark.parametrize('distname, shapename, shapes',
  336. cases_test_discrete_integer_shapes())
  337. def test_integer_shapes(distname, shapename, shapes):
  338. dist = getattr(stats, distname)
  339. shape_info = dist._shape_info()
  340. shape_names = [shape.name for shape in shape_info]
  341. i = shape_names.index(shapename) # this element of params must be integral
  342. shapes_copy = list(shapes)
  343. valid_shape = shapes[i]
  344. invalid_shape = valid_shape - 0.5 # arbitrary non-integral value
  345. new_valid_shape = valid_shape - 1
  346. shapes_copy[i] = [[valid_shape], [invalid_shape], [new_valid_shape]]
  347. a, b = dist.support(*shapes)
  348. x = np.round(np.linspace(a, b, 5))
  349. pmf = dist.pmf(x, *shapes_copy)
  350. assert not np.any(np.isnan(pmf[0, :]))
  351. assert np.all(np.isnan(pmf[1, :]))
  352. assert not np.any(np.isnan(pmf[2, :]))
  353. def test_frozen_attributes(monkeypatch):
  354. # gh-14827 reported that all frozen distributions had both pmf and pdf
  355. # attributes; continuous should have pdf and discrete should have pmf.
  356. message = "'rv_discrete_frozen' object has no attribute"
  357. with pytest.raises(AttributeError, match=message):
  358. stats.binom(10, 0.5).pdf
  359. with pytest.raises(AttributeError, match=message):
  360. stats.binom(10, 0.5).logpdf
  361. monkeypatch.setattr(stats.binom, "pdf", "herring", raising=False)
  362. frozen_binom = stats.binom(10, 0.5)
  363. assert isinstance(frozen_binom, rv_discrete_frozen)
  364. assert not hasattr(frozen_binom, "pdf")
  365. @pytest.mark.parametrize('distname, shapes', distdiscrete)
  366. def test_interval(distname, shapes):
  367. # gh-11026 reported that `interval` returns incorrect values when
  368. # `confidence=1`. The values were not incorrect, but it was not intuitive
  369. # that the left end of the interval should extend beyond the support of the
  370. # distribution. Confirm that this is the behavior for all distributions.
  371. if isinstance(distname, str):
  372. dist = getattr(stats, distname)
  373. else:
  374. dist = distname
  375. a, b = dist.support(*shapes)
  376. npt.assert_equal(dist.ppf([0, 1], *shapes), (a-1, b))
  377. npt.assert_equal(dist.isf([1, 0], *shapes), (a-1, b))
  378. npt.assert_equal(dist.interval(1, *shapes), (a-1, b))
  379. @pytest.mark.xfail_on_32bit("Sensible to machine precision")
  380. def test_rv_sample():
  381. # Thoroughly test rv_sample and check that gh-3758 is resolved
  382. # Generate a random discrete distribution
  383. rng = np.random.default_rng(98430143469)
  384. xk = np.sort(rng.random(10) * 10)
  385. pk = rng.random(10)
  386. pk /= np.sum(pk)
  387. dist = stats.rv_discrete(values=(xk, pk))
  388. # Generate points to the left and right of xk
  389. xk_left = (np.array([0] + xk[:-1].tolist()) + xk)/2
  390. xk_right = (np.array(xk[1:].tolist() + [xk[-1]+1]) + xk)/2
  391. # Generate points to the left and right of cdf
  392. cdf2 = np.cumsum(pk)
  393. cdf2_left = (np.array([0] + cdf2[:-1].tolist()) + cdf2)/2
  394. cdf2_right = (np.array(cdf2[1:].tolist() + [1]) + cdf2)/2
  395. # support - leftmost and rightmost xk
  396. a, b = dist.support()
  397. assert_allclose(a, xk[0])
  398. assert_allclose(b, xk[-1])
  399. # pmf - supported only on the xk
  400. assert_allclose(dist.pmf(xk), pk)
  401. assert_allclose(dist.pmf(xk_right), 0)
  402. assert_allclose(dist.pmf(xk_left), 0)
  403. # logpmf is log of the pmf; log(0) = -np.inf
  404. with np.errstate(divide='ignore'):
  405. assert_allclose(dist.logpmf(xk), np.log(pk))
  406. assert_allclose(dist.logpmf(xk_right), -np.inf)
  407. assert_allclose(dist.logpmf(xk_left), -np.inf)
  408. # cdf - the cumulative sum of the pmf
  409. assert_allclose(dist.cdf(xk), cdf2)
  410. assert_allclose(dist.cdf(xk_right), cdf2)
  411. assert_allclose(dist.cdf(xk_left), [0]+cdf2[:-1].tolist())
  412. with np.errstate(divide='ignore'):
  413. assert_allclose(dist.logcdf(xk), np.log(dist.cdf(xk)),
  414. atol=1e-15)
  415. assert_allclose(dist.logcdf(xk_right), np.log(dist.cdf(xk_right)),
  416. atol=1e-15)
  417. assert_allclose(dist.logcdf(xk_left), np.log(dist.cdf(xk_left)),
  418. atol=1e-15)
  419. # sf is 1-cdf
  420. assert_allclose(dist.sf(xk), 1-dist.cdf(xk))
  421. assert_allclose(dist.sf(xk_right), 1-dist.cdf(xk_right))
  422. assert_allclose(dist.sf(xk_left), 1-dist.cdf(xk_left))
  423. with np.errstate(divide='ignore'):
  424. assert_allclose(dist.logsf(xk), np.log(dist.sf(xk)),
  425. atol=1e-15)
  426. assert_allclose(dist.logsf(xk_right), np.log(dist.sf(xk_right)),
  427. atol=1e-15)
  428. assert_allclose(dist.logsf(xk_left), np.log(dist.sf(xk_left)),
  429. atol=1e-15)
  430. # ppf
  431. assert_allclose(dist.ppf(cdf2), xk)
  432. assert_allclose(dist.ppf(cdf2_left), xk)
  433. assert_allclose(dist.ppf(cdf2_right)[:-1], xk[1:])
  434. assert_allclose(dist.ppf(0), a - 1)
  435. assert_allclose(dist.ppf(1), b)
  436. # isf
  437. sf2 = dist.sf(xk)
  438. assert_allclose(dist.isf(sf2), xk)
  439. assert_allclose(dist.isf(1-cdf2_left), dist.ppf(cdf2_left))
  440. assert_allclose(dist.isf(1-cdf2_right), dist.ppf(cdf2_right))
  441. assert_allclose(dist.isf(0), b)
  442. assert_allclose(dist.isf(1), a - 1)
  443. # interval is (ppf(alpha/2), isf(alpha/2))
  444. ps = np.linspace(0.01, 0.99, 10)
  445. int2 = dist.ppf(ps/2), dist.isf(ps/2)
  446. assert_allclose(dist.interval(1-ps), int2)
  447. assert_allclose(dist.interval(0), dist.median())
  448. assert_allclose(dist.interval(1), (a-1, b))
  449. # median is simply ppf(0.5)
  450. med2 = dist.ppf(0.5)
  451. assert_allclose(dist.median(), med2)
  452. # all four stats (mean, var, skew, and kurtosis) from the definitions
  453. mean2 = np.sum(xk*pk)
  454. var2 = np.sum((xk - mean2)**2 * pk)
  455. skew2 = np.sum((xk - mean2)**3 * pk) / var2**(3/2)
  456. kurt2 = np.sum((xk - mean2)**4 * pk) / var2**2 - 3
  457. assert_allclose(dist.mean(), mean2)
  458. assert_allclose(dist.std(), np.sqrt(var2))
  459. assert_allclose(dist.var(), var2)
  460. assert_allclose(dist.stats(moments='mvsk'), (mean2, var2, skew2, kurt2))
  461. # noncentral moment against definition
  462. mom3 = np.sum((xk**3) * pk)
  463. assert_allclose(dist.moment(3), mom3)
  464. # expect - check against moments
  465. assert_allclose(dist.expect(lambda x: 1), 1)
  466. assert_allclose(dist.expect(), mean2)
  467. assert_allclose(dist.expect(lambda x: x**3), mom3)
  468. # entropy is the negative of the expected value of log(p)
  469. with np.errstate(divide='ignore'):
  470. assert_allclose(-dist.expect(lambda x: dist.logpmf(x)), dist.entropy())
  471. # RVS is just ppf of uniform random variates
  472. rng = np.random.default_rng(98430143469)
  473. rvs = dist.rvs(size=100, random_state=rng)
  474. rng = np.random.default_rng(98430143469)
  475. rvs0 = dist.ppf(rng.random(size=100))
  476. assert_allclose(rvs, rvs0)
  477. def test__pmf_float_input():
  478. # gh-21272
  479. # test that `rvs()` can be computed when `_pmf` requires float input
  480. class rv_exponential(stats.rv_discrete):
  481. def _pmf(self, i):
  482. return (2/3)*3**(1 - i)
  483. rv = rv_exponential(a=0.0, b=float('inf'))
  484. rvs = rv.rvs(random_state=42) # should not crash due to integer input to `_pmf`
  485. assert_allclose(rvs, 0)
  486. def test_gh18919_ppf_array_args():
  487. # gh-18919 reported incorrect results for ppf and isf of discrete distributions when
  488. # arguments were arrays and first argument (`q`) had elements at the boundaries of
  489. # the support.
  490. q = [[0.5, 1.0, 0.5],
  491. [1.0, 0.5, 1.0],
  492. [0.5, 1.0, 0.5]]
  493. n = [[45, 46, 47],
  494. [48, 49, 50],
  495. [51, 52, 53]]
  496. p = 0.5
  497. ref = _ufuncs._binom_ppf(q, n, p)
  498. res = stats.binom.ppf(q, n, p)
  499. np.testing.assert_allclose(res, ref)
  500. @pytest.mark.parametrize("dist", [stats.binom, stats.boltzmann])
  501. def test_gh18919_ppf_isf_array_args2(dist):
  502. # a more general version of the test above. Requires that arguments are broadcasted
  503. # by the infrastructure.
  504. rng = np.random.default_rng(34873457824358729823)
  505. q = rng.random(size=(30, 1, 1, 1))
  506. n = rng.integers(10, 30, size=(10, 1, 1))
  507. p = rng.random(size=(4, 1))
  508. loc = rng.integers(5, size=(3,))
  509. q[rng.random(size=30) > 0.7] = 0
  510. q[rng.random(size=30) > 0.7] = 1
  511. args = (q, n, p) if dist == stats.binom else (q, p, n)
  512. res = dist.ppf(*args, loc=loc)
  513. ref = np.vectorize(dist.ppf)(*args) + loc
  514. np.testing.assert_allclose(res, ref)
  515. res = dist.isf(*args, loc=loc)
  516. ref = np.vectorize(dist.isf)(*args) + loc
  517. np.testing.assert_allclose(res, ref)