test_continuous.py 91 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194
  1. import itertools as it
  2. import os
  3. import pickle
  4. from copy import deepcopy
  5. import numpy as np
  6. from numpy import inf
  7. import pytest
  8. from numpy.testing import assert_allclose, assert_equal
  9. from hypothesis import strategies, given, reproduce_failure, settings # noqa: F401
  10. import hypothesis.extra.numpy as npst
  11. from scipy import special
  12. from scipy import stats
  13. from scipy.stats._fit import _kolmogorov_smirnov
  14. from scipy.stats._ksstats import kolmogn
  15. from scipy.stats import qmc
  16. from scipy.stats._distr_params import distcont, distdiscrete
  17. from scipy.stats._distribution_infrastructure import (
  18. _Domain, _RealInterval, _Parameter, _Parameterization, _RealParameter,
  19. ContinuousDistribution, ShiftedScaledDistribution, _fiinfo,
  20. _generate_domain_support, Mixture)
  21. from scipy.stats._new_distributions import StandardNormal, _LogUniform, _Gamma
  22. from scipy.stats._new_distributions import DiscreteDistribution
  23. from scipy.stats import Normal, Logistic, Uniform, Binomial
  24. class Test_RealInterval:
  25. rng = np.random.default_rng(349849812549824)
  26. def test_iv(self):
  27. domain = _RealInterval(endpoints=('a', 'b'))
  28. message = "The endpoints of the distribution are defined..."
  29. with pytest.raises(TypeError, match=message):
  30. domain.get_numerical_endpoints(dict)
  31. @pytest.mark.parametrize('x', [rng.uniform(10, 10, size=(2, 3, 4)),
  32. -np.inf, np.pi])
  33. def test_contains_simple(self, x):
  34. # Test `contains` when endpoints are defined by constants
  35. a, b = -np.inf, np.pi
  36. domain = _RealInterval(endpoints=(a, b), inclusive=(False, True))
  37. assert_equal(domain.contains(x), (a < x) & (x <= b))
  38. @pytest.mark.slow
  39. @given(shapes=npst.mutually_broadcastable_shapes(num_shapes=3, min_side=0),
  40. inclusive_a=strategies.booleans(),
  41. inclusive_b=strategies.booleans(),
  42. data=strategies.data())
  43. def test_contains(self, shapes, inclusive_a, inclusive_b, data):
  44. # Test `contains` when endpoints are defined by parameters
  45. input_shapes, result_shape = shapes
  46. shape_a, shape_b, shape_x = input_shapes
  47. # Without defining min and max values, I spent forever trying to set
  48. # up a valid test without overflows or similar just drawing arrays.
  49. a_elements = dict(allow_nan=False, allow_infinity=False,
  50. min_value=-1e3, max_value=1)
  51. b_elements = dict(allow_nan=False, allow_infinity=False,
  52. min_value=2, max_value=1e3)
  53. a = data.draw(npst.arrays(npst.floating_dtypes(),
  54. shape_a, elements=a_elements))
  55. b = data.draw(npst.arrays(npst.floating_dtypes(),
  56. shape_b, elements=b_elements))
  57. # ensure some points are to the left, some to the right, and some
  58. # are exactly on the boundary
  59. d = b - a
  60. x = np.concatenate([np.linspace(a-d, a, 10),
  61. np.linspace(a, b, 10),
  62. np.linspace(b, b+d, 10)])
  63. # Domain is defined by two parameters, 'a' and 'b'
  64. domain = _RealInterval(endpoints=('a', 'b'),
  65. inclusive=(inclusive_a, inclusive_b))
  66. domain.define_parameters(_RealParameter('a', domain=_RealInterval()),
  67. _RealParameter('b', domain=_RealInterval()))
  68. # Check that domain and string evaluation give the same result
  69. res = domain.contains(x, dict(a=a, b=b))
  70. # Apparently, `np.float16([2]) < np.float32(2.0009766)` is False
  71. # but `np.float16([2]) < np.float32([2.0009766])` is True
  72. # dtype = np.result_type(a.dtype, b.dtype, x.dtype)
  73. # a, b, x = a.astype(dtype), b.astype(dtype), x.astype(dtype)
  74. # unclear whether we should be careful about this, since it will be
  75. # fixed with NEP50. Just do what makes the test pass.
  76. left_comparison = '<=' if inclusive_a else '<'
  77. right_comparison = '<=' if inclusive_b else '<'
  78. ref = eval(f'(a {left_comparison} x) & (x {right_comparison} b)')
  79. assert_equal(res, ref)
  80. @pytest.mark.parametrize("inclusive", list(it.product([True, False], repeat=2)))
  81. @pytest.mark.parametrize("a,b", [(0, 1), (3, 1)])
  82. def test_contains_function_endpoints(self, inclusive, a, b):
  83. # Test `contains` when endpoints are defined by functions.
  84. endpoints = (lambda a, b: (a - b) / 2, lambda a, b: (a + b) / 2)
  85. domain = _RealInterval(endpoints=endpoints, inclusive=inclusive)
  86. x = np.asarray([(a - 2*b)/2, (a - b)/2, a/2, (a + b)/2, (a + 2*b)/2])
  87. res = domain.contains(x, dict(a=a, b=b))
  88. numerical_endpoints = ((a - b) / 2, (a + b) / 2)
  89. assert numerical_endpoints == domain.get_numerical_endpoints(dict(a=a, b=b))
  90. alpha, beta = numerical_endpoints
  91. above_left = alpha <= x if inclusive[0] else alpha < x
  92. below_right = x <= beta if inclusive[1] else x < beta
  93. ref = above_left & below_right
  94. assert_equal(res, ref)
  95. @pytest.mark.parametrize('case', [
  96. (-np.inf, np.pi, False, True, r"(-\infty, \pi]"),
  97. ('a', 5, True, False, "[a, 5)")
  98. ])
  99. def test_str(self, case):
  100. domain = _RealInterval(endpoints=case[:2], inclusive=case[2:4])
  101. assert str(domain) == case[4]
  102. @pytest.mark.slow
  103. @given(a=strategies.one_of(
  104. strategies.decimals(allow_nan=False),
  105. strategies.characters(whitelist_categories="L"), # type: ignore[arg-type]
  106. strategies.sampled_from(list(_Domain.symbols))),
  107. b=strategies.one_of(
  108. strategies.decimals(allow_nan=False),
  109. strategies.characters(whitelist_categories="L"), # type: ignore[arg-type]
  110. strategies.sampled_from(list(_Domain.symbols))),
  111. inclusive_a=strategies.booleans(),
  112. inclusive_b=strategies.booleans(),
  113. )
  114. def test_str2(self, a, b, inclusive_a, inclusive_b):
  115. # I wrote this independently from the implementation of __str__, but
  116. # I imagine it looks pretty similar to __str__.
  117. a = _Domain.symbols.get(a, a)
  118. b = _Domain.symbols.get(b, b)
  119. left_bracket = '[' if inclusive_a else '('
  120. right_bracket = ']' if inclusive_b else ')'
  121. domain = _RealInterval(endpoints=(a, b),
  122. inclusive=(inclusive_a, inclusive_b))
  123. ref = f"{left_bracket}{a}, {b}{right_bracket}"
  124. assert str(domain) == ref
  125. def test_symbols_gh22137(self):
  126. # `symbols` was accidentally shared between instances originally
  127. # Check that this is no longer the case
  128. domain1 = _RealInterval(endpoints=(0, 1))
  129. domain2 = _RealInterval(endpoints=(0, 1))
  130. assert domain1.symbols is not domain2.symbols
  131. def draw_distribution_from_family(family, data, rng, proportions, min_side=0):
  132. # If the distribution has parameters, choose a parameterization and
  133. # draw broadcastable shapes for the parameter arrays.
  134. n_parameterizations = family._num_parameterizations()
  135. if n_parameterizations > 0:
  136. i = data.draw(strategies.integers(0, max_value=n_parameterizations-1))
  137. n_parameters = family._num_parameters(i)
  138. shapes, result_shape = data.draw(
  139. npst.mutually_broadcastable_shapes(num_shapes=n_parameters,
  140. min_side=min_side))
  141. dist = family._draw(shapes, rng=rng, proportions=proportions,
  142. i_parameterization=i)
  143. else:
  144. dist = family._draw(rng=rng)
  145. result_shape = tuple()
  146. # Draw a broadcastable shape for the arguments, and draw values for the
  147. # arguments.
  148. x_shape = data.draw(npst.broadcastable_shapes(result_shape,
  149. min_side=min_side))
  150. x = dist._variable.draw(x_shape, parameter_values=dist._parameters,
  151. proportions=proportions, rng=rng, region='typical')
  152. x_result_shape = np.broadcast_shapes(x_shape, result_shape)
  153. y_shape = data.draw(npst.broadcastable_shapes(x_result_shape,
  154. min_side=min_side))
  155. y = dist._variable.draw(y_shape, parameter_values=dist._parameters,
  156. proportions=proportions, rng=rng, region='typical')
  157. xy_result_shape = np.broadcast_shapes(y_shape, x_result_shape)
  158. p_domain = _RealInterval((0, 1), (True, True))
  159. p_var = _RealParameter('p', domain=p_domain)
  160. p = p_var.draw(x_shape, proportions=proportions, rng=rng)
  161. with np.errstate(divide='ignore', invalid='ignore'):
  162. logp = np.log(p)
  163. return dist, x, y, p, logp, result_shape, x_result_shape, xy_result_shape
  164. continuous_families = [
  165. StandardNormal,
  166. Normal,
  167. Logistic,
  168. Uniform,
  169. _LogUniform
  170. ]
  171. discrete_families = [
  172. Binomial,
  173. ]
  174. families = continuous_families + discrete_families
  175. class TestDistributions:
  176. @pytest.mark.fail_slow(60) # need to break up check_moment_funcs
  177. @settings(max_examples=20)
  178. @pytest.mark.parametrize('family', families)
  179. @given(data=strategies.data(), seed=strategies.integers(min_value=0))
  180. def test_support_moments_sample(self, family, data, seed):
  181. rng = np.random.default_rng(seed)
  182. # relative proportions of valid, endpoint, out of bounds, and NaN params
  183. proportions = (0.7, 0.1, 0.1, 0.1)
  184. tmp = draw_distribution_from_family(family, data, rng, proportions)
  185. dist, x, y, p, logp, result_shape, x_result_shape, xy_result_shape = tmp
  186. sample_shape = data.draw(npst.array_shapes(min_dims=0, min_side=0,
  187. max_side=20))
  188. with np.errstate(invalid='ignore', divide='ignore'):
  189. check_support(dist)
  190. check_moment_funcs(dist, result_shape) # this needs to get split up
  191. check_sample_shape_NaNs(dist, 'sample', sample_shape, result_shape, rng)
  192. qrng = qmc.Halton(d=1, seed=rng)
  193. check_sample_shape_NaNs(dist, 'sample', sample_shape, result_shape, qrng)
  194. @pytest.mark.fail_slow(10)
  195. @pytest.mark.parametrize('family', families)
  196. @pytest.mark.parametrize('func, methods, arg',
  197. [('entropy', {'log/exp', 'quadrature'}, None),
  198. ('logentropy', {'log/exp', 'quadrature'}, None),
  199. ('median', {'icdf'}, None),
  200. ('mode', {'optimization'}, None),
  201. ('mean', {'cache'}, None),
  202. ('variance', {'cache'}, None),
  203. ('skewness', {'cache'}, None),
  204. ('kurtosis', {'cache'}, None),
  205. ('pdf', {'log/exp'}, 'x'),
  206. ('logpdf', {'log/exp'}, 'x'),
  207. ('logcdf', {'log/exp', 'complement', 'quadrature'}, 'x'),
  208. ('cdf', {'log/exp', 'complement', 'quadrature'}, 'x'),
  209. ('logccdf', {'log/exp', 'complement', 'quadrature'}, 'x'),
  210. ('ccdf', {'log/exp', 'complement', 'quadrature'}, 'x'),
  211. ('ilogccdf', {'complement', 'inversion'}, 'logp'),
  212. ('iccdf', {'complement', 'inversion'}, 'p'),
  213. ])
  214. @settings(max_examples=20)
  215. @given(data=strategies.data(), seed=strategies.integers(min_value=0))
  216. def test_funcs(self, family, data, seed, func, methods, arg):
  217. if family == Uniform and func == 'mode':
  218. pytest.skip("Mode is not unique; `method`s disagree.")
  219. rng = np.random.default_rng(seed)
  220. # relative proportions of valid, endpoint, out of bounds, and NaN params
  221. proportions = (0.7, 0.1, 0.1, 0.1)
  222. tmp = draw_distribution_from_family(family, data, rng, proportions)
  223. dist, x, y, p, logp, result_shape, x_result_shape, xy_result_shape = tmp
  224. args = {'x': x, 'p': p, 'logp': p}
  225. with np.errstate(invalid='ignore', divide='ignore', over='ignore'):
  226. if arg is None:
  227. check_dist_func(dist, func, None, result_shape, methods)
  228. elif arg in args:
  229. check_dist_func(dist, func, args[arg], x_result_shape, methods)
  230. if func == 'variance':
  231. assert_allclose(dist.standard_deviation()**2, dist.variance())
  232. # invalid and divide are to be expected; maybe look into over
  233. with np.errstate(invalid='ignore', divide='ignore', over='ignore'):
  234. if not isinstance(dist, ShiftedScaledDistribution):
  235. if func == 'cdf':
  236. methods = {'quadrature'}
  237. check_cdf2(dist, False, x, y, xy_result_shape, methods)
  238. check_cdf2(dist, True, x, y, xy_result_shape, methods)
  239. elif func == 'ccdf':
  240. methods = {'addition'}
  241. check_ccdf2(dist, False, x, y, xy_result_shape, methods)
  242. check_ccdf2(dist, True, x, y, xy_result_shape, methods)
  243. def test_plot(self):
  244. try:
  245. import matplotlib.pyplot as plt
  246. except ImportError:
  247. return
  248. X = Uniform(a=0., b=1.)
  249. ax = X.plot()
  250. assert ax == plt.gca()
  251. @pytest.mark.parametrize('method_name', ['cdf', 'ccdf'])
  252. def test_complement_safe(self, method_name):
  253. X = stats.Normal()
  254. X.tol = 1e-12
  255. p = np.asarray([1e-4, 1e-3])
  256. func = getattr(X, method_name)
  257. ifunc = getattr(X, 'i'+method_name)
  258. x = ifunc(p, method='formula')
  259. p1 = func(x, method='complement_safe')
  260. p2 = func(x, method='complement')
  261. assert_equal(p1[1], p2[1])
  262. assert p1[0] != p2[0]
  263. assert_allclose(p1[0], p[0], rtol=X.tol)
  264. @pytest.mark.parametrize('method_name', ['cdf', 'ccdf'])
  265. def test_icomplement_safe(self, method_name):
  266. X = stats.Normal()
  267. X.tol = 1e-12
  268. p = np.asarray([1e-4, 1e-3])
  269. func = getattr(X, method_name)
  270. ifunc = getattr(X, 'i'+method_name)
  271. x1 = ifunc(p, method='complement_safe')
  272. x2 = ifunc(p, method='complement')
  273. assert_equal(x1[1], x2[1])
  274. assert x1[0] != x2[0]
  275. assert_allclose(func(x1[0]), p[0], rtol=X.tol)
  276. def test_subtraction_safe(self):
  277. X = stats.Normal()
  278. X.tol = 1e-12
  279. # Regular subtraction is fine in either tail (and of course, across tails)
  280. x = [-11, -10, 10, 11]
  281. y = [-10, -11, 11, 10]
  282. p0 = X.cdf(x, y, method='quadrature')
  283. p1 = X.cdf(x, y, method='subtraction_safe')
  284. p2 = X.cdf(x, y, method='subtraction')
  285. assert_equal(p2, p1)
  286. assert_allclose(p1, p0, rtol=X.tol)
  287. # Safe subtraction is needed in special cases
  288. x = np.asarray([-1e-20, -1e-21, 1e-20, 1e-21, -1e-20])
  289. y = np.asarray([-1e-21, -1e-20, 1e-21, 1e-20, 1e-20])
  290. p0 = X.pdf(0)*(y-x)
  291. p1 = X.cdf(x, y, method='subtraction_safe')
  292. p2 = X.cdf(x, y, method='subtraction')
  293. assert_equal(p2, 0)
  294. assert_allclose(p1, p0, rtol=X.tol)
  295. def test_logentropy_safe(self):
  296. # simulate an `entropy` calculation over/underflowing with extreme parameters
  297. class _Normal(stats.Normal):
  298. def _entropy_formula(self, **params):
  299. out = np.asarray(super()._entropy_formula(**params))
  300. out[0] = 0
  301. out[-1] = np.inf
  302. return out
  303. X = _Normal(sigma=[1, 2, 3])
  304. with np.errstate(divide='ignore'):
  305. res1 = X.logentropy(method='logexp_safe')
  306. res2 = X.logentropy(method='logexp')
  307. ref = X.logentropy(method='quadrature')
  308. i_fl = [0, -1] # first and last
  309. assert np.isinf(res2[i_fl]).all()
  310. assert res1[1] == res2[1]
  311. # quadrature happens to be perfectly accurate on some platforms
  312. # assert res1[1] != ref[1]
  313. assert_equal(res1[i_fl], ref[i_fl])
  314. def test_logcdf2_safe(self):
  315. # test what happens when 2-arg `cdf` underflows
  316. X = stats.Normal(sigma=[1, 2, 3])
  317. x = [-301, 1, 300]
  318. y = [-300, 2, 301]
  319. with np.errstate(divide='ignore'):
  320. res1 = X.logcdf(x, y, method='logexp_safe')
  321. res2 = X.logcdf(x, y, method='logexp')
  322. ref = X.logcdf(x, y, method='quadrature')
  323. i_fl = [0, -1] # first and last
  324. assert np.isinf(res2[i_fl]).all()
  325. assert res1[1] == res2[1]
  326. # quadrature happens to be perfectly accurate on some platforms
  327. # assert res1[1] != ref[1]
  328. assert_equal(res1[i_fl], ref[i_fl])
  329. @pytest.mark.parametrize('method_name', ['logcdf', 'logccdf'])
  330. def test_logexp_safe(self, method_name):
  331. # test what happens when `cdf`/`ccdf` underflows
  332. X = stats.Normal(sigma=2)
  333. x = [-301, 1] if method_name == 'logcdf' else [301, 1]
  334. func = getattr(X, method_name)
  335. with np.errstate(divide='ignore'):
  336. res1 = func(x, method='logexp_safe')
  337. res2 = func(x, method='logexp')
  338. ref = func(x, method='quadrature')
  339. assert res1[0] == ref[0]
  340. assert res1[0] != res2[0]
  341. assert res1[1] == res2[1]
  342. assert res1[1] != ref[1]
  343. def check_sample_shape_NaNs(dist, fname, sample_shape, result_shape, rng):
  344. full_shape = sample_shape + result_shape
  345. if fname == 'sample':
  346. sample_method = dist.sample
  347. methods = {'inverse_transform'}
  348. if dist._overrides(f'_{fname}_formula') and not isinstance(rng, qmc.QMCEngine):
  349. methods.add('formula')
  350. for method in methods:
  351. res = sample_method(sample_shape, method=method, rng=rng)
  352. valid_parameters = np.broadcast_to(get_valid_parameters(dist),
  353. res.shape)
  354. assert_equal(res.shape, full_shape)
  355. np.testing.assert_equal(res.dtype, dist._dtype)
  356. if full_shape == ():
  357. # NumPy random makes a distinction between a 0d array and a scalar.
  358. # In stats, we consistently turn 0d arrays into scalars, so
  359. # maintain that behavior here. (With Array API arrays, this will
  360. # change.)
  361. assert np.isscalar(res)
  362. assert np.all(np.isfinite(res[valid_parameters]))
  363. assert_equal(res[~valid_parameters], np.nan)
  364. sample1 = sample_method(sample_shape, method=method, rng=42)
  365. sample2 = sample_method(sample_shape, method=method, rng=42)
  366. if not isinstance(dist, DiscreteDistribution):
  367. # The idea is that it's very unlikely that the random sample
  368. # for a randomly chosen seed will match that for seed 42,
  369. # but it is not so unlikely if `dist` is a discrete distribution.
  370. assert not np.any(np.equal(res, sample1))
  371. assert_equal(sample1, sample2)
  372. def check_support(dist):
  373. a, b = dist.support()
  374. check_nans_and_edges(dist, 'support', None, a)
  375. check_nans_and_edges(dist, 'support', None, b)
  376. assert a.shape == dist._shape
  377. assert b.shape == dist._shape
  378. assert a.dtype == dist._dtype
  379. assert b.dtype == dist._dtype
  380. def check_dist_func(dist, fname, arg, result_shape, methods):
  381. # Check that all computation methods of all distribution functions agree
  382. # with one another, effectively testing the correctness of the generic
  383. # computation methods and confirming the consistency of specific
  384. # distributions with their pdf/logpdf.
  385. args = tuple() if arg is None else (arg,)
  386. methods = methods.copy()
  387. if "cache" in methods:
  388. # If "cache" is specified before the value has been evaluated, it
  389. # raises an error. After the value is evaluated, it will succeed.
  390. with pytest.raises(NotImplementedError):
  391. getattr(dist, fname)(*args, method="cache")
  392. ref = getattr(dist, fname)(*args)
  393. check_nans_and_edges(dist, fname, arg, ref)
  394. # Remove this after fixing `draw`
  395. tol_override = {'atol': 1e-15}
  396. # Mean can be 0, which makes logmean -inf.
  397. if fname in {'logmean', 'mean', 'logskewness', 'skewness'}:
  398. tol_override = {'atol': 1e-15}
  399. elif fname in {'mode'}:
  400. # can only expect about half of machine precision for optimization
  401. # because math
  402. tol_override = {'atol': 1e-6}
  403. elif fname in {'logcdf'}: # gh-22276
  404. tol_override = {'rtol': 2e-7}
  405. if dist._overrides(f'_{fname}_formula'):
  406. methods.add('formula')
  407. np.testing.assert_equal(ref.shape, result_shape)
  408. # Until we convert to array API, let's do the familiar thing:
  409. # 0d things are scalars, not arrays
  410. if result_shape == tuple():
  411. assert np.isscalar(ref)
  412. for method in methods:
  413. res = getattr(dist, fname)(*args, method=method)
  414. if 'log' in fname:
  415. np.testing.assert_allclose(np.exp(res), np.exp(ref),
  416. **tol_override)
  417. else:
  418. np.testing.assert_allclose(res, ref, **tol_override)
  419. # for now, make sure dtypes are consistent; later, we can check whether
  420. # they are correct.
  421. np.testing.assert_equal(res.dtype, ref.dtype)
  422. np.testing.assert_equal(res.shape, result_shape)
  423. if result_shape == tuple():
  424. assert np.isscalar(res)
  425. def check_cdf2(dist, log, x, y, result_shape, methods):
  426. # Specialized test for 2-arg cdf since the interface is a bit different
  427. # from the other methods. Here, we'll use 1-arg cdf as a reference, and
  428. # since we have already checked 1-arg cdf in `check_nans_and_edges`, this
  429. # checks the equivalent of both `check_dist_func` and
  430. # `check_nans_and_edges`.
  431. methods = methods.copy()
  432. if log:
  433. if dist._overrides('_logcdf2_formula'):
  434. methods.add('formula')
  435. if dist._overrides('_logcdf_formula') or dist._overrides('_logccdf_formula'):
  436. methods.add('subtraction')
  437. if (dist._overrides('_cdf_formula')
  438. or dist._overrides('_ccdf_formula')):
  439. methods.add('log/exp')
  440. else:
  441. if dist._overrides('_cdf2_formula'):
  442. methods.add('formula')
  443. if dist._overrides('_cdf_formula') or dist._overrides('_ccdf_formula'):
  444. methods.add('subtraction')
  445. if (dist._overrides('_logcdf_formula')
  446. or dist._overrides('_logccdf_formula')):
  447. methods.add('log/exp')
  448. ref = dist.cdf(y) - dist.cdf(x)
  449. np.testing.assert_equal(ref.shape, result_shape)
  450. if result_shape == tuple():
  451. assert np.isscalar(ref)
  452. for method in methods:
  453. if isinstance(dist, DiscreteDistribution):
  454. message = ("Two argument cdf functions are currently only supported for "
  455. "continuous distributions.")
  456. with pytest.raises(NotImplementedError, match=message):
  457. res = (np.exp(dist.logcdf(x, y, method=method)) if log
  458. else dist.cdf(x, y, method=method))
  459. continue
  460. res = (np.exp(dist.logcdf(x, y, method=method)) if log
  461. else dist.cdf(x, y, method=method))
  462. np.testing.assert_allclose(res, ref, atol=1e-14)
  463. if log:
  464. np.testing.assert_equal(res.dtype, (ref + 0j).dtype)
  465. else:
  466. np.testing.assert_equal(res.dtype, ref.dtype)
  467. np.testing.assert_equal(res.shape, result_shape)
  468. if result_shape == tuple():
  469. assert np.isscalar(res)
  470. def check_ccdf2(dist, log, x, y, result_shape, methods):
  471. # Specialized test for 2-arg ccdf since the interface is a bit different
  472. # from the other methods. Could be combined with check_cdf2 above, but
  473. # writing it separately is simpler.
  474. methods = methods.copy()
  475. if dist._overrides(f'_{"log" if log else ""}ccdf2_formula'):
  476. methods.add('formula')
  477. ref = dist.cdf(x) + dist.ccdf(y)
  478. np.testing.assert_equal(ref.shape, result_shape)
  479. if result_shape == tuple():
  480. assert np.isscalar(ref)
  481. for method in methods:
  482. message = ("Two argument cdf functions are currently only supported for "
  483. "continuous distributions.")
  484. if isinstance(dist, DiscreteDistribution):
  485. with pytest.raises(NotImplementedError, match=message):
  486. res = (np.exp(dist.logccdf(x, y, method=method)) if log
  487. else dist.ccdf(x, y, method=method))
  488. continue
  489. res = (np.exp(dist.logccdf(x, y, method=method)) if log
  490. else dist.ccdf(x, y, method=method))
  491. np.testing.assert_allclose(res, ref, atol=1e-14)
  492. np.testing.assert_equal(res.dtype, ref.dtype)
  493. np.testing.assert_equal(res.shape, result_shape)
  494. if result_shape == tuple():
  495. assert np.isscalar(res)
  496. def check_nans_and_edges(dist, fname, arg, res):
  497. valid_parameters = get_valid_parameters(dist)
  498. if fname in {'icdf', 'iccdf'}:
  499. arg_domain = _RealInterval(endpoints=(0, 1), inclusive=(True, True))
  500. elif fname in {'ilogcdf', 'ilogccdf'}:
  501. arg_domain = _RealInterval(endpoints=(-inf, 0), inclusive=(True, True))
  502. else:
  503. arg_domain = dist._variable.domain
  504. classified_args = classify_arg(dist, arg, arg_domain)
  505. valid_parameters, *classified_args = np.broadcast_arrays(valid_parameters,
  506. *classified_args)
  507. valid_arg, endpoint_arg, outside_arg, nan_arg = classified_args
  508. all_valid = valid_arg & valid_parameters
  509. # Check NaN pattern and edge cases
  510. assert_equal(res[~valid_parameters], np.nan)
  511. assert_equal(res[nan_arg], np.nan)
  512. a, b = dist.support()
  513. a = np.broadcast_to(a, res.shape)
  514. b = np.broadcast_to(b, res.shape)
  515. outside_arg_minus = (outside_arg == -1) & valid_parameters
  516. outside_arg_plus = (outside_arg == 1) & valid_parameters
  517. endpoint_arg_minus = (endpoint_arg == -1) & valid_parameters
  518. endpoint_arg_plus = (endpoint_arg == 1) & valid_parameters
  519. is_discrete = isinstance(dist, DiscreteDistribution)
  520. # Writing this independently of how the are set in the distribution
  521. # infrastructure. That is very compact; this is very verbose.
  522. if fname in {'logpdf'}:
  523. assert_equal(res[outside_arg_minus], -np.inf)
  524. assert_equal(res[outside_arg_plus], -np.inf)
  525. ref = -np.inf if not is_discrete else np.inf
  526. assert_equal(res[endpoint_arg_minus & ~valid_arg], ref)
  527. assert_equal(res[endpoint_arg_plus & ~valid_arg], ref)
  528. elif fname in {'pdf'}:
  529. assert_equal(res[outside_arg_minus], 0)
  530. assert_equal(res[outside_arg_plus], 0)
  531. ref = 0 if not is_discrete else np.inf
  532. assert_equal(res[endpoint_arg_minus & ~valid_arg], ref)
  533. assert_equal(res[endpoint_arg_plus & ~valid_arg], ref)
  534. elif fname in {'logcdf'} and not is_discrete:
  535. assert_equal(res[outside_arg_minus], -inf)
  536. assert_equal(res[outside_arg_plus], 0)
  537. assert_equal(res[endpoint_arg_minus], -inf)
  538. assert_equal(res[endpoint_arg_plus], 0)
  539. elif fname in {'cdf'} and not is_discrete:
  540. assert_equal(res[outside_arg_minus], 0)
  541. assert_equal(res[outside_arg_plus], 1)
  542. assert_equal(res[endpoint_arg_minus], 0)
  543. assert_equal(res[endpoint_arg_plus], 1)
  544. elif fname in {'logccdf'} and not is_discrete:
  545. assert_equal(res[outside_arg_minus], 0)
  546. assert_equal(res[outside_arg_plus], -inf)
  547. assert_equal(res[endpoint_arg_minus], 0)
  548. assert_equal(res[endpoint_arg_plus], -inf)
  549. elif fname in {'ccdf'} and not is_discrete:
  550. assert_equal(res[outside_arg_minus], 1)
  551. assert_equal(res[outside_arg_plus], 0)
  552. assert_equal(res[endpoint_arg_minus], 1)
  553. assert_equal(res[endpoint_arg_plus], 0)
  554. elif fname in {'ilogcdf', 'icdf'} and not is_discrete:
  555. assert_equal(res[outside_arg == -1], np.nan)
  556. assert_equal(res[outside_arg == 1], np.nan)
  557. assert_equal(res[endpoint_arg == -1], a[endpoint_arg == -1])
  558. assert_equal(res[endpoint_arg == 1], b[endpoint_arg == 1])
  559. elif fname in {'ilogccdf', 'iccdf'} and not is_discrete:
  560. assert_equal(res[outside_arg == -1], np.nan)
  561. assert_equal(res[outside_arg == 1], np.nan)
  562. assert_equal(res[endpoint_arg == -1], b[endpoint_arg == -1])
  563. assert_equal(res[endpoint_arg == 1], a[endpoint_arg == 1])
  564. exclude = {'logmean', 'mean', 'logskewness', 'skewness', 'support'}
  565. if isinstance(dist, DiscreteDistribution):
  566. exclude.update({'pdf', 'logpdf'})
  567. if fname not in exclude:
  568. assert np.isfinite(res[all_valid & (endpoint_arg == 0)]).all()
def check_moment_funcs(dist, result_shape):
    """Check that the moment computation methods of `dist` agree.

    Checks that all computation methods of all distribution functions agree
    with one another, effectively testing the correctness of the generic
    computation methods and confirming the consistency of specific
    distributions with their pdf/logpdf.

    Raw, central and standardized moments of orders 0-5 are compared against
    a ``method='quadrature'`` reference. Several checks depend on what is
    currently in the moment cache, so the order of the calls below (and of
    the ``reset_cache`` calls between them) is significant.

    Parameters
    ----------
    dist : distribution instance under test
    result_shape : tuple
        Shape expected of every reference moment.
    """

    atol = 1e-9  # make this tighter (e.g. 1e-13) after fixing `draw`

    def check(order, kind, method=None, ref=None, success=True):
        # With `success`, the moment must match `ref` in value (absolute
        # tolerance scaled by 10**order) and shape; otherwise the requested
        # method must raise NotImplementedError.
        if success:
            res = dist.moment(order, kind, method=method)
            assert_allclose(res, ref, atol=atol*10**order)
            assert res.shape == ref.shape
        else:
            with pytest.raises(NotImplementedError):
                dist.moment(order, kind, method=method)

    def has_formula(order, kind):
        # True if `dist` overrides the formula for this kind of moment and
        # the formula covers `order`. A formula without an `orders`
        # attribute is treated as covering orders 0-5.
        formula_name = f'_moment_{kind}_formula'
        overrides = dist._overrides(formula_name)
        if not overrides:
            return False
        formula = getattr(dist, formula_name)
        orders = getattr(formula, 'orders', set(range(6)))
        return order in orders

    dist.reset_cache()

    ### Check Raw Moments ###
    for i in range(6):
        check(i, 'raw', 'cache', success=False)  # not cached yet
        ref = dist.moment(i, 'raw', method='quadrature')
        check_nans_and_edges(dist, 'moment', None, ref)
        assert ref.shape == result_shape
        check(i, 'raw','cache', ref, success=True)  # cached now
        check(i, 'raw', 'formula', ref, success=has_formula(i, 'raw'))
        check(i, 'raw', 'general', ref, success=(i == 0))
        if dist.__class__ == stats.Normal:
            check(i, 'raw', 'quadrature_icdf', ref, success=True)

    # Clearing caches to better check their behavior
    dist.reset_cache()

    # If we have central or standard moment formulas, or if there are
    # values in their cache, we can use method='transform'
    dist.moment(0, 'central')  # build up the cache
    dist.moment(1, 'central')
    for i in range(2, 6):
        ref = dist.moment(i, 'raw', method='quadrature')
        check(i, 'raw', 'transform', ref,
              success=has_formula(i, 'central') or has_formula(i, 'standardized'))
        dist.moment(i, 'central')  # build up the cache
        check(i, 'raw', 'transform', ref)

    dist.reset_cache()

    ### Check Central Moments ###

    for i in range(6):
        check(i, 'central', 'cache', success=False)
        ref = dist.moment(i, 'central', method='quadrature')
        assert ref.shape == result_shape
        check(i, 'central', 'cache', ref, success=True)
        check(i, 'central', 'formula', ref, success=has_formula(i, 'central'))
        check(i, 'central', 'general', ref, success=i <= 1)
        if dist.__class__ == stats.Normal:
            check(i, 'central', 'quadrature_icdf', ref, success=True)
        if not (dist.__class__ == stats.Uniform and i == 5):
            # Quadrature is not super accurate for 5th central moment when the
            # support is really big. Skip this one failing test. We need to come
            # up with a better system of skipping individual failures w/ hypothesis.
            check(i, 'central', 'transform', ref,
                  success=has_formula(i, 'raw') or (i <= 1))
        if not has_formula(i, 'raw'):
            # No raw formula: evaluate the raw moment first (presumably this
            # caches it), after which 'transform' is expected to succeed.
            dist.moment(i, 'raw')
            check(i, 'central', 'transform', ref)

    # Computed before the cache is cleared; used below to decide whether
    # 'normalize' is expected to work (it is not when the variance is zero).
    variance = dist.variance()
    dist.reset_cache()

    # If we have standard moment formulas, or if there are
    # values in their cache, we can use method='normalize'
    dist.moment(0, 'standardized')  # build up the cache
    dist.moment(1, 'standardized')
    dist.moment(2, 'standardized')
    for i in range(3, 6):
        ref = dist.moment(i, 'central', method='quadrature')
        check(i, 'central', 'normalize', ref,
              success=has_formula(i, 'standardized') and not np.any(variance == 0))
        dist.moment(i, 'standardized')  # build up the cache
        check(i, 'central', 'normalize', ref, success=not np.any(variance == 0))

    ### Check Standardized Moments ###

    var = dist.moment(2, 'central', method='quadrature')
    dist.reset_cache()

    for i in range(6):
        check(i, 'standardized', 'cache', success=False)
        # Reference: central moment divided by the variance to the power i/2
        ref = dist.moment(i, 'central', method='quadrature') / var ** (i / 2)
        assert ref.shape == result_shape
        check(i, 'standardized', 'formula', ref,
              success=has_formula(i, 'standardized'))
        check(i, 'standardized', 'general', ref, success=i <= 2)
        check(i, 'standardized', 'normalize', ref)

    if isinstance(dist, ShiftedScaledDistribution):
        # logmoment is not fully fleshed out; no need to test
        # ShiftedScaledDistribution here
        return

    # logmoment is not very accurate, and it's not public, so skip for now
    # ### Check Against _logmoment ###
    # logmean = dist._logmoment(1, logcenter=-np.inf)
    # for i in range(6):
    #     ref = np.exp(dist._logmoment(i, logcenter=-np.inf))
    #     assert_allclose(dist.moment(i, 'raw'), ref, atol=atol*10**i)
    #
    #     ref = np.exp(dist._logmoment(i, logcenter=logmean))
    #     assert_allclose(dist.moment(i, 'central'), ref, atol=atol*10**i)
    #
    #     ref = np.exp(dist._logmoment(i, logcenter=logmean, standardized=True))
    #     assert_allclose(dist.moment(i, 'standardized'), ref, atol=atol*10**i)
  675. @pytest.mark.parametrize('family', (Normal,))
  676. @pytest.mark.parametrize('x_shape', [tuple(), (2, 3)])
  677. @pytest.mark.parametrize('dist_shape', [tuple(), (4, 1)])
  678. @pytest.mark.parametrize('fname', ['sample'])
  679. @pytest.mark.parametrize('rng_type', [np.random.Generator, qmc.Halton, qmc.Sobol])
  680. def test_sample_against_cdf(family, dist_shape, x_shape, fname, rng_type):
  681. rng = np.random.default_rng(842582438235635)
  682. num_parameters = family._num_parameters()
  683. if dist_shape and num_parameters == 0:
  684. pytest.skip("Distribution can't have a shape without parameters.")
  685. dist = family._draw(dist_shape, rng)
  686. n = 1024
  687. sample_size = (n,) + x_shape
  688. sample_array_shape = sample_size + dist_shape
  689. if fname == 'sample':
  690. sample_method = dist.sample
  691. if rng_type != np.random.Generator:
  692. rng = rng_type(d=1, seed=rng)
  693. x = sample_method(sample_size, rng=rng)
  694. assert x.shape == sample_array_shape
  695. # probably should give `axis` argument to ks_1samp, review that separately
  696. statistic = _kolmogorov_smirnov(dist, x, axis=0)
  697. pvalue = kolmogn(x.shape[0], statistic, cdf=False)
  698. p_threshold = 0.01
  699. num_pvalues = pvalue.size
  700. num_small_pvalues = np.sum(pvalue < p_threshold)
  701. assert num_small_pvalues < p_threshold * num_pvalues
  702. def get_valid_parameters(dist):
  703. # Given a distribution, return a logical array that is true where all
  704. # distribution parameters are within their respective domains. The code
  705. # here is probably quite similar to that used to form the `_invalid`
  706. # attribute of the distribution, but this was written about a week later
  707. # without referring to that code, so it is a somewhat independent check.
  708. # Get all parameter values and `_Parameter` objects
  709. parameter_values = dist._parameters
  710. parameters = {}
  711. for parameterization in dist._parameterizations:
  712. parameters.update(parameterization.parameters)
  713. all_valid = np.ones(dist._shape, dtype=bool)
  714. for name, value in parameter_values.items():
  715. if name not in parameters: # cached value not part of parameterization
  716. continue
  717. parameter = parameters[name]
  718. # Check that the numerical endpoints and inclusivity attribute
  719. # agree with the `contains` method about which parameter values are
  720. # within the domain.
  721. a, b = parameter.domain.get_numerical_endpoints(
  722. parameter_values=parameter_values)
  723. a_included, b_included = parameter.domain.inclusive
  724. valid = (a <= value) if a_included else a < value
  725. valid &= (value <= b) if b_included else value < b
  726. assert_equal(valid, parameter.domain.contains(
  727. value, parameter_values=parameter_values))
  728. # Form `all_valid` mask that is True where *all* parameters are valid
  729. all_valid &= valid
  730. # Check that the `all_valid` mask formed here is the complement of the
  731. # `dist._invalid` mask stored by the infrastructure
  732. assert_equal(~all_valid, dist._invalid)
  733. return all_valid
  734. def classify_arg(dist, arg, arg_domain):
  735. if arg is None:
  736. valid_args = np.ones(dist._shape, dtype=bool)
  737. endpoint_args = np.zeros(dist._shape, dtype=bool)
  738. outside_args = np.zeros(dist._shape, dtype=bool)
  739. nan_args = np.zeros(dist._shape, dtype=bool)
  740. return valid_args, endpoint_args, outside_args, nan_args
  741. a, b = arg_domain.get_numerical_endpoints(
  742. parameter_values=dist._parameters)
  743. a, b, arg = np.broadcast_arrays(a, b, arg)
  744. a_included, b_included = arg_domain.inclusive
  745. inside = (a <= arg) if a_included else a < arg
  746. inside &= (arg <= b) if b_included else arg < b
  747. # TODO: add `supported` method and check here
  748. on = np.zeros(a.shape, dtype=int)
  749. on[a == arg] = -1
  750. on[b == arg] = 1
  751. outside = np.zeros(a.shape, dtype=int)
  752. outside[(arg < a) if a_included else arg <= a] = -1
  753. outside[(b < arg) if b_included else b <= arg] = 1
  754. nan = np.isnan(arg)
  755. return inside, on, outside, nan
  756. def test_input_validation():
  757. class Test(ContinuousDistribution):
  758. _variable = _RealParameter('x', domain=_RealInterval())
  759. message = ("The `Test` distribution family does not accept parameters, "
  760. "but parameters `{'a'}` were provided.")
  761. with pytest.raises(ValueError, match=message):
  762. Test(a=1, )
  763. message = "Attribute `tol` of `Test` must be a positive float, if specified."
  764. with pytest.raises(ValueError, match=message):
  765. Test(tol=np.asarray([]))
  766. with pytest.raises(ValueError, match=message):
  767. Test(tol=[1, 2, 3])
  768. with pytest.raises(ValueError, match=message):
  769. Test(tol=np.nan)
  770. with pytest.raises(ValueError, match=message):
  771. Test(tol=-1)
  772. message = ("Argument `order` of `Test.moment` must be a "
  773. "finite, positive integer.")
  774. with pytest.raises(ValueError, match=message):
  775. Test().moment(-1)
  776. with pytest.raises(ValueError, match=message):
  777. Test().moment(np.inf)
  778. message = "Argument `kind` of `Test.moment` must be one of..."
  779. with pytest.raises(ValueError, match=message):
  780. Test().moment(2, kind='coconut')
  781. class Test2(ContinuousDistribution):
  782. _p1 = _RealParameter('c', domain=_RealInterval())
  783. _p2 = _RealParameter('d', domain=_RealInterval())
  784. _parameterizations = [_Parameterization(_p1, _p2)]
  785. _variable = _RealParameter('x', domain=_RealInterval())
  786. message = ("The provided parameters `{a}` do not match a supported "
  787. "parameterization of the `Test2` distribution family.")
  788. with pytest.raises(ValueError, match=message):
  789. Test2(a=1)
  790. message = ("The `Test2` distribution family requires parameters, but none "
  791. "were provided.")
  792. with pytest.raises(ValueError, match=message):
  793. Test2()
  794. message = ("The parameters `{c, d}` provided to the `Test2` "
  795. "distribution family cannot be broadcast to the same shape.")
  796. with pytest.raises(ValueError, match=message):
  797. Test2(c=[1, 2], d=[1, 2, 3])
  798. message = ("The argument provided to `Test2.pdf` cannot be be broadcast to "
  799. "the same shape as the distribution parameters.")
  800. with pytest.raises(ValueError, match=message):
  801. dist = Test2(c=[1, 2, 3], d=[1, 2, 3])
  802. dist.pdf([1, 2])
  803. message = "Parameter `c` must be of real dtype."
  804. with pytest.raises(TypeError, match=message):
  805. Test2(c=[1, object()], d=[1, 2])
  806. message = "Parameter `convention` of `Test2.kurtosis` must be one of..."
  807. with pytest.raises(ValueError, match=message):
  808. dist = Test2(c=[1, 2, 3], d=[1, 2, 3])
  809. dist.kurtosis(convention='coconut')
  810. def test_rng_deepcopy_pickle():
  811. # test behavior of `rng` attribute and copy behavior
  812. kwargs = dict(a=[-1, 2], b=10)
  813. dist1 = Uniform(**kwargs)
  814. dist2 = deepcopy(dist1)
  815. dist3 = pickle.loads(pickle.dumps(dist1))
  816. res1, res2, res3 = dist1.sample(), dist2.sample(), dist3.sample()
  817. assert np.all(res2 != res1)
  818. assert np.all(res3 != res1)
  819. res1, res2, res3 = dist1.sample(rng=42), dist2.sample(rng=42), dist3.sample(rng=42)
  820. assert np.all(res2 == res1)
  821. assert np.all(res3 == res1)
class TestAttributes:
    """Tests of the `cache_policy`, `tol` and `validation_policy` attributes
    and of access to distribution parameters as attributes."""

    def test_cache_policy(self):
        """Check `method='cache'` under the different cache policies.

        The statements are order dependent: each step relies on the cache
        state left behind by the previous ones.
        """
        dist = StandardNormal(cache_policy="no_cache")
        # make error message more appropriate
        message = "`StandardNormal` does not provide an accurate implementation of the "
        with pytest.raises(NotImplementedError, match=message):
            dist.mean(method='cache')
        # With "no_cache", evaluating the mean must not populate the cache
        mean = dist.mean()
        with pytest.raises(NotImplementedError, match=message):
            dist.mean(method='cache')

        # add to enum
        dist.cache_policy = None
        with pytest.raises(NotImplementedError, match=message):
            dist.mean(method='cache')
        mean = dist.mean()  # method is 'formula' by default
        cached_mean = dist.mean(method='cache')
        assert_equal(cached_mean, mean)

        # cache is overridden by latest evaluation
        quadrature_mean = dist.mean(method='quadrature')
        cached_mean = dist.mean(method='cache')
        assert_equal(cached_mean, quadrature_mean)
        assert not np.all(mean == quadrature_mean)

        # We can turn the cache off, and it won't change, but the old cache is
        # still available
        dist.cache_policy = "no_cache"
        mean = dist.mean(method='formula')
        cached_mean = dist.mean(method='cache')
        assert_equal(cached_mean, quadrature_mean)
        assert not np.all(mean == quadrature_mean)

        # After a reset, nothing is available from the cache any more
        dist.reset_cache()
        with pytest.raises(NotImplementedError, match=message):
            dist.mean(method='cache')

        message = "Attribute `cache_policy` of `StandardNormal`..."
        with pytest.raises(ValueError, match=message):
            dist.cache_policy = "invalid"

    def test_tol(self):
        """Check validation of `tol` and its effect on numerical accuracy."""
        x = 3.
        X = stats.Normal()

        message = "Attribute `tol` of `StandardNormal` must..."
        with pytest.raises(ValueError, match=message):
            X.tol = -1.
        with pytest.raises(ValueError, match=message):
            X.tol = (0.1,)
        with pytest.raises(ValueError, match=message):
            X.tol = np.nan

        X1 = stats.Normal(tol=1e-1)
        X2 = stats.Normal(tol=1e-12)
        ref = X.cdf(x)
        res1 = X1.cdf(x, method='quadrature')
        res2 = X2.cdf(x, method='quadrature')
        assert_allclose(res1, ref, rtol=X1.tol)
        assert_allclose(res2, ref, rtol=X2.tol)
        # The looser tolerance must give the larger error
        assert abs(res1 - ref) > abs(res2 - ref)

        p = 0.99
        # Swap the tolerances, so that X2 is now the loose one
        X1.tol, X2.tol = X2.tol, X1.tol
        ref = X.icdf(p)
        res1 = X1.icdf(p, method='inversion')
        res2 = X2.icdf(p, method='inversion')
        assert_allclose(res1, ref, rtol=X1.tol)
        assert_allclose(res2, ref, rtol=X2.tol)
        assert abs(res2 - ref) > abs(res1 - ref)

    def test_iv_policy(self):
        """Check that `validation_policy='skip_all'` disables input validation.

        With validation skipped, the asserted values are whatever the
        formulas return for out-of-support arguments and invalid parameters.
        """
        X = Uniform(a=0, b=1)
        assert X.pdf(2) == 0

        X.validation_policy = 'skip_all'
        assert X.pdf(np.asarray(2.)) == 1

        # Tests _set_invalid_nan
        a, b = np.asarray(1.), np.asarray(0.)  # invalid parameters
        X = Uniform(a=a, b=b, validation_policy='skip_all')
        assert X.pdf(np.asarray(2.)) == -1

        # Tests _set_invalid_nan_property
        class MyUniform(Uniform):
            # Formulas that return a sentinel, so that the assertions below
            # can tell that the value was passed through unchanged
            def _entropy_formula(self, *args, **kwargs):
                return 'incorrect'

            def _moment_raw_formula(self, order, **params):
                return 'incorrect'

        X = MyUniform(a=a, b=b, validation_policy='skip_all')
        assert X.entropy() == 'incorrect'

        # Tests _validate_order_kind
        assert X.moment(kind='raw', order=-1) == 'incorrect'

        # Test input validation
        message = "Attribute `validation_policy` of `MyUniform`..."
        with pytest.raises(ValueError, match=message):
            X.validation_policy = "invalid"

    def test_shapes(self):
        """Check that parameters are readable, but not writable, attributes."""
        X = stats.Normal(mu=1, sigma=2)
        Y = stats.Normal(mu=[2], sigma=3)

        # Check that attributes are available as expected
        assert X.mu == 1
        assert X.sigma == 2
        assert Y.mu[0] == 2
        assert Y.sigma[0] == 3

        # Trying to set an attribute raises
        # message depends on Python version
        with pytest.raises(AttributeError):
            X.mu = 2

        # Trying to mutate an attribute really mutates a copy
        Y.mu[0] = 10
        assert Y.mu[0] == 2
  921. class TestMakeDistribution:
  922. @pytest.mark.parametrize('i, distdata', enumerate(distcont + distdiscrete))
  923. def test_rv_generic(self, i, distdata):
  924. distname = distdata[0]
  925. slow = {'argus', 'exponpow', 'exponweib', 'genexpon', 'gompertz', 'halfgennorm',
  926. 'johnsonsb', 'kappa4', 'ksone', 'kstwo', 'kstwobign', 'norminvgauss',
  927. 'powerlognorm', 'powernorm', 'recipinvgauss', 'studentized_range',
  928. 'vonmises_line', # continuous
  929. 'betanbinom', 'logser', 'zipf'} # discrete
  930. if not int(os.environ.get('SCIPY_XSLOW', '0')) and distname in slow:
  931. pytest.skip('Skipping as XSLOW')
  932. if distname in { # skip these distributions
  933. 'levy_stable', # private methods seem to require >= 1d args
  934. 'vonmises', # circular distribution; shouldn't work
  935. 'poisson_binom', # vector shape parameter
  936. 'hypergeom', # distribution functions need interpolation
  937. 'nchypergeom_fisher', # distribution functions need interpolation
  938. 'nchypergeom_wallenius', # distribution functions need interpolation
  939. }:
  940. return
  941. # skip single test, mostly due to slight disagreement
  942. custom_tolerances = {'ksone': 1e-5, 'kstwo': 1e-5} # discontinuous PDF
  943. skip_entropy = {'kstwobign', 'pearson3'} # tolerance issue
  944. skip_skewness = {'exponpow', 'ksone', 'nchypergeom_wallenius'} # tolerance
  945. skip_kurtosis = {'chi', 'exponpow', 'invgamma', # tolerance
  946. 'johnsonsb', 'ksone', 'kstwo', # tolerance
  947. 'nchypergeom_wallenius'} # tolerance
  948. skip_logccdf = {'arcsine', 'skewcauchy', 'trapezoid', 'triang'} # tolerance
  949. skip_raw = {2: {'alpha', 'foldcauchy', 'halfcauchy', 'levy', 'levy_l'},
  950. 3: {'pareto'}, # stats.pareto is just wrong
  951. 4: {'invgamma'}} # tolerance issue
  952. skip_standardized = {'exponpow', 'ksone'} # tolerances
  953. dist = getattr(stats, distname)
  954. params = dict(zip(dist.shapes.split(', '), distdata[1])) if dist.shapes else {}
  955. rng = np.random.default_rng(7548723590230982)
  956. CustomDistribution = stats.make_distribution(dist)
  957. X = CustomDistribution(**params)
  958. Y = dist(**params)
  959. x = X.sample(shape=10, rng=rng)
  960. p = X.cdf(x)
  961. rtol = custom_tolerances.get(distname, 1e-7)
  962. atol = 1e-12
  963. with np.errstate(divide='ignore', invalid='ignore'):
  964. m, v, s, k = Y.stats('mvsk')
  965. assert_allclose(X.support(), Y.support())
  966. if distname not in skip_entropy:
  967. assert_allclose(X.entropy(), Y.entropy(), rtol=rtol)
  968. if isinstance(Y, stats.rv_discrete):
  969. # some continuous distributions have trouble with `logentropy` because
  970. # it uses complex numbers
  971. assert_allclose(np.exp(X.logentropy()), Y.entropy(), rtol=rtol)
  972. assert_allclose(X.median(), Y.median(), rtol=rtol)
  973. assert_allclose(X.mean(), m, rtol=rtol, atol=atol)
  974. assert_allclose(X.variance(), v, rtol=rtol, atol=atol)
  975. if distname not in skip_skewness:
  976. assert_allclose(X.skewness(), s, rtol=rtol, atol=atol)
  977. if distname not in skip_kurtosis:
  978. assert_allclose(X.kurtosis(convention='excess'), k,
  979. rtol=rtol, atol=atol)
  980. if isinstance(dist, stats.rv_continuous):
  981. assert_allclose(X.logpdf(x), Y.logpdf(x), rtol=rtol)
  982. assert_allclose(X.pdf(x), Y.pdf(x), rtol=rtol)
  983. else:
  984. assert_allclose(X.logpmf(x), Y.logpmf(x), rtol=rtol)
  985. assert_allclose(X.pmf(x), Y.pmf(x), rtol=rtol)
  986. assert_allclose(X.logcdf(x), Y.logcdf(x), rtol=rtol)
  987. assert_allclose(X.cdf(x), Y.cdf(x), rtol=rtol)
  988. if distname not in skip_logccdf:
  989. assert_allclose(X.logccdf(x), Y.logsf(x), rtol=rtol)
  990. assert_allclose(X.ccdf(x), Y.sf(x), rtol=rtol)
  991. # old infrastructure convention for ppf(p=0) and isf(p=1) is different than
  992. # new infrastructure. Adjust reference values accordingly.
  993. a, _ = Y.support()
  994. ref_ppf = Y.ppf(p)
  995. ref_ppf[p == 0] = a
  996. ref_isf = Y.isf(p)
  997. ref_isf[p == 1] = a
  998. assert_allclose(X.icdf(p), ref_ppf, rtol=rtol)
  999. assert_allclose(X.iccdf(p), ref_isf, rtol=rtol)
  1000. for order in range(5):
  1001. if distname not in skip_raw.get(order, {}):
  1002. assert_allclose(X.moment(order, kind='raw'),
  1003. Y.moment(order), rtol=rtol, atol=atol)
  1004. for order in range(3, 4):
  1005. if distname not in skip_standardized:
  1006. assert_allclose(X.moment(order, kind='standardized'),
  1007. Y.stats('mvsk'[order-1]), rtol=rtol, atol=atol)
  1008. if isinstance(dist, stats.rv_continuous):
  1009. # For discrete distributions, these won't agree at the far left end
  1010. # of the support, and the new infrastructure is slow there (for now).
  1011. seed = 845298245687345
  1012. assert_allclose(X.sample(shape=10, rng=seed),
  1013. Y.rvs(size=10,
  1014. random_state=np.random.default_rng(seed)),
  1015. rtol=rtol)
  1016. def test_custom(self):
  1017. rng = np.random.default_rng(7548723590230982)
  1018. class MyLogUniform:
  1019. @property
  1020. def __make_distribution_version__(self):
  1021. return "1.16.0"
  1022. @property
  1023. def parameters(self):
  1024. return {'a': {'endpoints': (0, np.inf), 'inclusive': (False, False)},
  1025. 'b': {'endpoints': ('a', np.inf), 'inclusive': (False, False)}}
  1026. @property
  1027. def support(self):
  1028. return {'endpoints': ('a', 'b')}
  1029. def pdf(self, x, a, b):
  1030. return 1 / (x * (np.log(b) - np.log(a)))
  1031. def sample(self, shape, *, a, b, rng=None):
  1032. p = rng.uniform(size=shape)
  1033. return np.exp(np.log(a) + p * (np.log(b) - np.log(a)))
  1034. def moment(self, order, kind='raw', *, a, b):
  1035. if order == 1 and kind == 'raw':
  1036. # quadrature is perfectly accurate here; add 1e-10 error so we
  1037. # can tell the difference between the two
  1038. return (b - a) / np.log(b/a) + 1e-10
  1039. LogUniform = stats.make_distribution(MyLogUniform())
  1040. X = LogUniform(a=1., b=np.e)
  1041. Y = stats.exp(Uniform(a=0., b=1.))
  1042. # pre-2.0 support is not needed for much longer, so let's just test with 2.0+
  1043. if np.__version__ >= "2.0":
  1044. assert str(X) == f"MyLogUniform(a=1.0, b={np.e})"
  1045. assert repr(X) == f"MyLogUniform(a=np.float64(1.0), b=np.float64({np.e}))"
  1046. x = X.sample(shape=10, rng=rng)
  1047. p = X.cdf(x)
  1048. assert_allclose(X.support(), Y.support())
  1049. assert_allclose(X.entropy(), Y.entropy())
  1050. assert_allclose(X.median(), Y.median())
  1051. assert_allclose(X.logpdf(x), Y.logpdf(x))
  1052. assert_allclose(X.pdf(x), Y.pdf(x))
  1053. assert_allclose(X.logcdf(x), Y.logcdf(x))
  1054. assert_allclose(X.cdf(x), Y.cdf(x))
  1055. assert_allclose(X.logccdf(x), Y.logccdf(x))
  1056. assert_allclose(X.ccdf(x), Y.ccdf(x))
  1057. assert_allclose(X.icdf(p), Y.icdf(p))
  1058. assert_allclose(X.iccdf(p), Y.iccdf(p))
  1059. for kind in ['raw', 'central', 'standardized']:
  1060. for order in range(5):
  1061. assert_allclose(X.moment(order, kind=kind),
  1062. Y.moment(order, kind=kind))
  1063. # Confirm that the `sample` and `moment` methods are overriden as expected
  1064. sample_formula = X.sample(shape=10, rng=0, method='formula')
  1065. sample_inverse = X.sample(shape=10, rng=0, method='inverse_transform')
  1066. assert_allclose(sample_formula, sample_inverse)
  1067. assert not np.all(sample_formula == sample_inverse)
  1068. assert_allclose(X.mean(method='formula'), X.mean(method='quadrature'))
  1069. assert not X.mean(method='formula') == X.mean(method='quadrature')
  1070. # pdf and cdf formulas below can warn on boundary of support in some cases.
  1071. # See https://github.com/scipy/scipy/pull/22560#discussion_r1962763840.
  1072. @pytest.mark.slow
  1073. @pytest.mark.filterwarnings("ignore::RuntimeWarning")
  1074. @pytest.mark.parametrize("c", [-1, 0, 1, np.asarray([-2.1, -1., 0., 1., 2.1])])
  1075. def test_custom_variable_support(self, c):
  1076. rng = np.random.default_rng(7548723590230982)
  1077. class MyGenExtreme:
  1078. @property
  1079. def __make_distribution_version__(self):
  1080. return "1.16.0"
  1081. @property
  1082. def parameters(self):
  1083. return {
  1084. 'c': {'endpoints': (-np.inf, np.inf), 'inclusive': (False, False)},
  1085. 'mu': {'endpoints': (-np.inf, np.inf), 'inclusive': (False, False)},
  1086. 'sigma': {'endpoints': (0, np.inf), 'inclusive': (False, False)}
  1087. }
  1088. @property
  1089. def support(self):
  1090. def left(*, c, mu, sigma):
  1091. c, mu, sigma = np.broadcast_arrays(c, mu, sigma)
  1092. result = np.empty_like(c)
  1093. result[c >= 0] = -np.inf
  1094. result[c < 0] = mu[c < 0] + sigma[c < 0] / c[c < 0]
  1095. return result[()]
  1096. def right(*, c, mu, sigma):
  1097. c, mu, sigma = np.broadcast_arrays(c, mu, sigma)
  1098. result = np.empty_like(c)
  1099. result[c <= 0] = np.inf
  1100. result[c > 0] = mu[c > 0] + sigma[c > 0] / c[c > 0]
  1101. return result[()]
  1102. return {"endpoints": (left, right), "inclusive": (False, False)}
  1103. def pdf(self, x, *, c, mu, sigma):
  1104. x, c, mu, sigma = np.broadcast_arrays(x, c, mu, sigma)
  1105. t = np.empty_like(x)
  1106. mask = (c == 0)
  1107. t[mask] = np.exp(-(x[mask] - mu[mask])/sigma[mask])
  1108. t[~mask] = (
  1109. 1 - c[~mask]*(x[~mask] - mu[~mask])/sigma[~mask]
  1110. )**(1/c[~mask])
  1111. result = 1/sigma * t**(1 - c)*np.exp(-t)
  1112. return result[()]
  1113. def cdf(self, x, *, c, mu, sigma):
  1114. x, c, mu, sigma = np.broadcast_arrays(x, c, mu, sigma)
  1115. t = np.empty_like(x)
  1116. mask = (c == 0)
  1117. t[mask] = np.exp(-(x[mask] - mu[mask])/sigma[mask])
  1118. t[~mask] = (
  1119. 1 - c[~mask]*(x[~mask] - mu[~mask])/sigma[~mask]
  1120. )**(1/c[~mask])
  1121. return np.exp(-t)[()]
  1122. GenExtreme1 = stats.make_distribution(MyGenExtreme())
  1123. GenExtreme2 = stats.make_distribution(stats.genextreme)
  1124. X1 = GenExtreme1(c=c, mu=0, sigma=1)
  1125. X2 = GenExtreme2(c=c)
  1126. x = X1.sample(shape=10, rng=rng)
  1127. p = X1.cdf(x)
  1128. assert_allclose(X1.support(), X2.support())
  1129. assert_allclose(X1.entropy(), X2.entropy(), rtol=5e-6)
  1130. assert_allclose(X1.median(), X2.median())
  1131. assert_allclose(X1.logpdf(x), X2.logpdf(x))
  1132. assert_allclose(X1.pdf(x), X2.pdf(x))
  1133. assert_allclose(X1.logcdf(x), X2.logcdf(x))
  1134. assert_allclose(X1.cdf(x), X2.cdf(x))
  1135. assert_allclose(X1.logccdf(x), X2.logccdf(x))
  1136. assert_allclose(X1.ccdf(x), X2.ccdf(x))
  1137. assert_allclose(X1.icdf(p), X2.icdf(p))
  1138. assert_allclose(X1.iccdf(p), X2.iccdf(p))
  1139. @pytest.mark.slow
  1140. @pytest.mark.parametrize("a", [0.5, np.asarray([0.5, 1.0, 2.0, 4.0, 8.0])])
  1141. @pytest.mark.parametrize("b", [0.5, np.asarray([0.5, 1.0, 2.0, 4.0, 8.0])])
  1142. def test_custom_multiple_parameterizations(self, a, b):
  1143. rng = np.random.default_rng(7548723590230982)
  1144. class MyBeta:
  1145. @property
  1146. def __make_distribution_version__(self):
  1147. return "1.16.0"
  1148. @property
  1149. def parameters(self):
  1150. return (
  1151. {"a": (0, np.inf), "b": (0, np.inf)},
  1152. {"mu": (0, 1), "nu": (0, np.inf)},
  1153. )
  1154. def process_parameters(self, a=None, b=None, mu=None, nu=None):
  1155. if a is not None and b is not None and mu is None and nu is None:
  1156. nu = a + b
  1157. mu = a / nu
  1158. else:
  1159. a = mu * nu
  1160. b = nu - a
  1161. return {"a": a, "b": b, "mu": mu, "nu": nu}
  1162. @property
  1163. def support(self):
  1164. return {'endpoints': (0, 1)}
  1165. def pdf(self, x, a, b, mu, nu):
  1166. return special._ufuncs._beta_pdf(x, a, b)
  1167. def cdf(self, x, a, b, mu, nu):
  1168. return special.betainc(a, b, x)
  1169. Beta = stats.make_distribution(stats.beta)
  1170. MyBeta = stats.make_distribution(MyBeta())
  1171. mu = a / (a + b)
  1172. nu = a + b
  1173. X = MyBeta(a=a, b=b)
  1174. Y = MyBeta(mu=mu, nu=nu)
  1175. Z = Beta(a=a, b=b)
  1176. x = Z.sample(shape=10, rng=rng)
  1177. p = Z.cdf(x)
  1178. assert_allclose(X.support(), Z.support())
  1179. assert_allclose(X.median(), Z.median())
  1180. assert_allclose(X.pdf(x), Z.pdf(x))
  1181. assert_allclose(X.cdf(x), Z.cdf(x))
  1182. assert_allclose(X.ccdf(x), Z.ccdf(x))
  1183. assert_allclose(X.icdf(p), Z.icdf(p))
  1184. assert_allclose(X.iccdf(p), Z.iccdf(p))
  1185. assert_allclose(Y.support(), Z.support())
  1186. assert_allclose(Y.median(), Z.median())
  1187. assert_allclose(Y.pdf(x), Z.pdf(x))
  1188. assert_allclose(Y.cdf(x), Z.cdf(x))
  1189. assert_allclose(Y.ccdf(x), Z.ccdf(x))
  1190. assert_allclose(Y.icdf(p), Z.icdf(p))
  1191. assert_allclose(Y.iccdf(p), Z.iccdf(p))
  1192. def test_input_validation(self):
  1193. message = '`levy_stable` is not supported.'
  1194. with pytest.raises(NotImplementedError, match=message):
  1195. stats.make_distribution(stats.levy_stable)
  1196. message = '`vonmises` is not supported.'
  1197. with pytest.raises(NotImplementedError, match=message):
  1198. stats.make_distribution(stats.vonmises)
  1199. message = "The argument must be an instance of..."
  1200. with pytest.raises(ValueError, match=message):
  1201. stats.make_distribution(object())
  1202. def test_repr_str_docs(self):
  1203. from scipy.stats._distribution_infrastructure import _distribution_names
  1204. for dist in _distribution_names.keys():
  1205. assert hasattr(stats, dist)
  1206. dist = stats.make_distribution(stats.gamma)
  1207. assert str(dist(a=2)) == "Gamma(a=2.0)"
  1208. if np.__version__ >= "2":
  1209. assert repr(dist(a=2)) == "Gamma(a=np.float64(2.0))"
  1210. assert 'Gamma' in dist.__doc__
  1211. dist = stats.make_distribution(stats.halfgennorm)
  1212. assert str(dist(beta=2)) == "HalfGeneralizedNormal(beta=2.0)"
  1213. if np.__version__ >= "2":
  1214. assert repr(dist(beta=2)) == "HalfGeneralizedNormal(beta=np.float64(2.0))"
  1215. assert 'HalfGeneralizedNormal' in dist.__doc__
  1216. class TestTransforms:
  1217. def test_ContinuousDistribution_only(self):
  1218. X = stats.Binomial(n=10, p=0.5)
  1219. # This is applied at the top level TransformedDistribution,
  1220. # so testing one subclass is enough
  1221. message = "Transformations are currently only supported for continuous RVs."
  1222. with pytest.raises(NotImplementedError, match=message):
  1223. stats.exp(X)
  1224. def test_truncate(self):
  1225. rng = np.random.default_rng(81345982345826)
  1226. lb = rng.random((3, 1))
  1227. ub = rng.random((3, 1))
  1228. lb, ub = np.minimum(lb, ub), np.maximum(lb, ub)
  1229. Y = stats.truncate(Normal(), lb=lb, ub=ub)
  1230. Y0 = stats.truncnorm(lb, ub)
  1231. y = Y0.rvs((3, 10), random_state=rng)
  1232. p = Y0.cdf(y)
  1233. assert_allclose(Y.logentropy(), np.log(Y0.entropy() + 0j))
  1234. assert_allclose(Y.entropy(), Y0.entropy())
  1235. assert_allclose(Y.median(), Y0.ppf(0.5))
  1236. assert_allclose(Y.mean(), Y0.mean())
  1237. assert_allclose(Y.variance(), Y0.var())
  1238. assert_allclose(Y.standard_deviation(), np.sqrt(Y0.var()))
  1239. assert_allclose(Y.skewness(), Y0.stats('s'))
  1240. assert_allclose(Y.kurtosis(), Y0.stats('k') + 3)
  1241. assert_allclose(Y.support(), Y0.support())
  1242. assert_allclose(Y.pdf(y), Y0.pdf(y))
  1243. assert_allclose(Y.cdf(y), Y0.cdf(y))
  1244. assert_allclose(Y.ccdf(y), Y0.sf(y))
  1245. assert_allclose(Y.icdf(p), Y0.ppf(p))
  1246. assert_allclose(Y.iccdf(p), Y0.isf(p))
  1247. assert_allclose(Y.logpdf(y), Y0.logpdf(y))
  1248. assert_allclose(Y.logcdf(y), Y0.logcdf(y))
  1249. assert_allclose(Y.logccdf(y), Y0.logsf(y))
  1250. assert_allclose(Y.ilogcdf(np.log(p)), Y0.ppf(p))
  1251. assert_allclose(Y.ilogccdf(np.log(p)), Y0.isf(p))
  1252. sample = Y.sample(10)
  1253. assert np.all((sample > lb) & (sample < ub))
  1254. @pytest.mark.fail_slow(10)
  1255. @given(data=strategies.data(), seed=strategies.integers(min_value=0))
  1256. def test_loc_scale(self, data, seed):
  1257. # Need tests with negative scale
  1258. rng = np.random.default_rng(seed)
  1259. class TransformedNormal(ShiftedScaledDistribution):
  1260. def __init__(self, *args, **kwargs):
  1261. super().__init__(StandardNormal(), *args, **kwargs)
  1262. tmp = draw_distribution_from_family(
  1263. TransformedNormal, data, rng, proportions=(1, 0, 0, 0), min_side=1)
  1264. dist, x, y, p, logp, result_shape, x_result_shape, xy_result_shape = tmp
  1265. loc = dist.loc
  1266. scale = dist.scale
  1267. dist0 = StandardNormal()
  1268. dist_ref = stats.norm(loc=loc, scale=scale)
  1269. x0 = (x - loc) / scale
  1270. y0 = (y - loc) / scale
  1271. a, b = dist.support()
  1272. a0, b0 = dist0.support()
  1273. assert_allclose(a, a0 + loc)
  1274. assert_allclose(b, b0 + loc)
  1275. with np.errstate(invalid='ignore', divide='ignore'):
  1276. assert_allclose(np.exp(dist.logentropy()), dist.entropy())
  1277. assert_allclose(dist.entropy(), dist_ref.entropy())
  1278. assert_allclose(dist.median(), dist0.median() + loc)
  1279. assert_allclose(dist.mode(), dist0.mode() + loc)
  1280. assert_allclose(dist.mean(), dist0.mean() + loc)
  1281. assert_allclose(dist.variance(), dist0.variance() * scale**2)
  1282. assert_allclose(dist.standard_deviation(), dist.variance()**0.5)
  1283. assert_allclose(dist.skewness(), dist0.skewness() * np.sign(scale))
  1284. assert_allclose(dist.kurtosis(), dist0.kurtosis())
  1285. assert_allclose(dist.logpdf(x), dist0.logpdf(x0) - np.log(scale))
  1286. assert_allclose(dist.pdf(x), dist0.pdf(x0) / scale)
  1287. assert_allclose(dist.logcdf(x), dist0.logcdf(x0))
  1288. assert_allclose(dist.cdf(x), dist0.cdf(x0))
  1289. assert_allclose(dist.logccdf(x), dist0.logccdf(x0))
  1290. assert_allclose(dist.ccdf(x), dist0.ccdf(x0))
  1291. assert_allclose(dist.logcdf(x, y), dist0.logcdf(x0, y0))
  1292. assert_allclose(dist.cdf(x, y), dist0.cdf(x0, y0))
  1293. assert_allclose(dist.logccdf(x, y), dist0.logccdf(x0, y0))
  1294. assert_allclose(dist.ccdf(x, y), dist0.ccdf(x0, y0))
  1295. assert_allclose(dist.ilogcdf(logp), dist0.ilogcdf(logp)*scale + loc)
  1296. assert_allclose(dist.icdf(p), dist0.icdf(p)*scale + loc)
  1297. assert_allclose(dist.ilogccdf(logp), dist0.ilogccdf(logp)*scale + loc)
  1298. assert_allclose(dist.iccdf(p), dist0.iccdf(p)*scale + loc)
  1299. for i in range(1, 5):
  1300. assert_allclose(dist.moment(i, 'raw'), dist_ref.moment(i))
  1301. assert_allclose(dist.moment(i, 'central'),
  1302. dist0.moment(i, 'central') * scale**i)
  1303. assert_allclose(dist.moment(i, 'standardized'),
  1304. dist0.moment(i, 'standardized') * np.sign(scale)**i)
  1305. # Transform back to the original distribution using all arithmetic
  1306. # operations; check that it behaves as expected.
  1307. dist = (dist - 2*loc) + loc
  1308. dist = dist/scale**2 * scale
  1309. z = np.zeros(dist._shape) # compact broadcasting
  1310. a, b = dist.support()
  1311. a0, b0 = dist0.support()
  1312. assert_allclose(a, a0 + z)
  1313. assert_allclose(b, b0 + z)
  1314. with np.errstate(invalid='ignore', divide='ignore'):
  1315. assert_allclose(dist.logentropy(), dist0.logentropy() + z)
  1316. assert_allclose(dist.entropy(), dist0.entropy() + z)
  1317. assert_allclose(dist.median(), dist0.median() + z)
  1318. assert_allclose(dist.mode(), dist0.mode() + z)
  1319. assert_allclose(dist.mean(), dist0.mean() + z)
  1320. assert_allclose(dist.variance(), dist0.variance() + z)
  1321. assert_allclose(dist.standard_deviation(), dist0.standard_deviation() + z)
  1322. assert_allclose(dist.skewness(), dist0.skewness() + z)
  1323. assert_allclose(dist.kurtosis(), dist0.kurtosis() + z)
  1324. assert_allclose(dist.logpdf(x), dist0.logpdf(x)+z)
  1325. assert_allclose(dist.pdf(x), dist0.pdf(x) + z)
  1326. assert_allclose(dist.logcdf(x), dist0.logcdf(x) + z)
  1327. assert_allclose(dist.cdf(x), dist0.cdf(x) + z)
  1328. assert_allclose(dist.logccdf(x), dist0.logccdf(x) + z)
  1329. assert_allclose(dist.ccdf(x), dist0.ccdf(x) + z)
  1330. assert_allclose(dist.ilogcdf(logp), dist0.ilogcdf(logp) + z)
  1331. assert_allclose(dist.icdf(p), dist0.icdf(p) + z)
  1332. assert_allclose(dist.ilogccdf(logp), dist0.ilogccdf(logp) + z)
  1333. assert_allclose(dist.iccdf(p), dist0.iccdf(p) + z)
  1334. for i in range(1, 5):
  1335. assert_allclose(dist.moment(i, 'raw'), dist0.moment(i, 'raw'))
  1336. assert_allclose(dist.moment(i, 'central'), dist0.moment(i, 'central'))
  1337. assert_allclose(dist.moment(i, 'standardized'),
  1338. dist0.moment(i, 'standardized'))
  1339. # These are tough to compare because of the way the shape works
  1340. # rng = np.random.default_rng(seed)
  1341. # rng0 = np.random.default_rng(seed)
  1342. # assert_allclose(dist.sample(x_result_shape, rng=rng),
  1343. # dist0.sample(x_result_shape, rng=rng0) * scale + loc)
  1344. # Should also try to test fit, plot?
  1345. @pytest.mark.fail_slow(5)
  1346. @pytest.mark.parametrize('exp_pow', ['exp', 'pow'])
  1347. def test_exp_pow(self, exp_pow):
  1348. rng = np.random.default_rng(81345982345826)
  1349. mu = rng.random((3, 1))
  1350. sigma = rng.random((3, 1))
  1351. X = Normal()*sigma + mu
  1352. if exp_pow == 'exp':
  1353. Y = stats.exp(X)
  1354. else:
  1355. Y = np.e ** X
  1356. Y0 = stats.lognorm(sigma, scale=np.exp(mu))
  1357. y = Y0.rvs((3, 10), random_state=rng)
  1358. p = Y0.cdf(y)
  1359. assert_allclose(Y.logentropy(), np.log(Y0.entropy()))
  1360. assert_allclose(Y.entropy(), Y0.entropy())
  1361. assert_allclose(Y.median(), Y0.ppf(0.5))
  1362. assert_allclose(Y.mean(), Y0.mean())
  1363. assert_allclose(Y.variance(), Y0.var())
  1364. assert_allclose(Y.standard_deviation(), np.sqrt(Y0.var()))
  1365. assert_allclose(Y.skewness(), Y0.stats('s'))
  1366. assert_allclose(Y.kurtosis(), Y0.stats('k') + 3)
  1367. assert_allclose(Y.support(), Y0.support())
  1368. assert_allclose(Y.pdf(y), Y0.pdf(y))
  1369. assert_allclose(Y.cdf(y), Y0.cdf(y))
  1370. assert_allclose(Y.ccdf(y), Y0.sf(y))
  1371. assert_allclose(Y.icdf(p), Y0.ppf(p))
  1372. assert_allclose(Y.iccdf(p), Y0.isf(p))
  1373. assert_allclose(Y.logpdf(y), Y0.logpdf(y))
  1374. assert_allclose(Y.logcdf(y), Y0.logcdf(y))
  1375. assert_allclose(Y.logccdf(y), Y0.logsf(y))
  1376. assert_allclose(Y.ilogcdf(np.log(p)), Y0.ppf(p))
  1377. assert_allclose(Y.ilogccdf(np.log(p)), Y0.isf(p))
  1378. seed = 3984593485
  1379. assert_allclose(Y.sample(rng=seed), np.exp(X.sample(rng=seed)))
  1380. @pytest.mark.fail_slow(10)
  1381. @pytest.mark.parametrize('scale', [1, 2, -1])
  1382. @pytest.mark.xfail_on_32bit("`scale=-1` fails on 32-bit; needs investigation")
  1383. def test_reciprocal(self, scale):
  1384. rng = np.random.default_rng(81345982345826)
  1385. a = rng.random((3, 1))
  1386. # Separate sign from scale. It's easy to scale the resulting
  1387. # RV with negative scale; we want to test the ability to divide
  1388. # by a RV with negative support
  1389. sign, scale = np.sign(scale), abs(scale)
  1390. # Reference distribution
  1391. InvGamma = stats.make_distribution(stats.invgamma)
  1392. Y0 = sign * scale * InvGamma(a=a)
  1393. # Test distribution
  1394. X = _Gamma(a=a) if sign > 0 else -_Gamma(a=a)
  1395. Y = scale / X
  1396. y = Y0.sample(shape=(3, 10), rng=rng)
  1397. p = Y0.cdf(y)
  1398. logp = np.log(p)
  1399. assert_allclose(Y.logentropy(), np.log(Y0.entropy()))
  1400. assert_allclose(Y.entropy(), Y0.entropy())
  1401. assert_allclose(Y.median(), Y0.median())
  1402. # moments are not finite
  1403. assert_allclose(Y.support(), Y0.support())
  1404. assert_allclose(Y.pdf(y), Y0.pdf(y))
  1405. assert_allclose(Y.cdf(y), Y0.cdf(y))
  1406. assert_allclose(Y.ccdf(y), Y0.ccdf(y))
  1407. assert_allclose(Y.icdf(p), Y0.icdf(p))
  1408. assert_allclose(Y.iccdf(p), Y0.iccdf(p))
  1409. assert_allclose(Y.logpdf(y), Y0.logpdf(y))
  1410. assert_allclose(Y.logcdf(y), Y0.logcdf(y))
  1411. assert_allclose(Y.logccdf(y), Y0.logccdf(y))
  1412. with np.errstate(divide='ignore', invalid='ignore'):
  1413. assert_allclose(Y.ilogcdf(logp), Y0.ilogcdf(logp))
  1414. assert_allclose(Y.ilogccdf(logp), Y0.ilogccdf(logp))
  1415. seed = 3984593485
  1416. assert_allclose(Y.sample(rng=seed), scale/(X.sample(rng=seed)))
  1417. @pytest.mark.fail_slow(5)
  1418. def test_log(self):
  1419. rng = np.random.default_rng(81345982345826)
  1420. a = rng.random((3, 1))
  1421. X = _Gamma(a=a)
  1422. Y0 = stats.loggamma(a)
  1423. Y = stats.log(X)
  1424. y = Y0.rvs((3, 10), random_state=rng)
  1425. p = Y0.cdf(y)
  1426. assert_allclose(Y.logentropy(), np.log(Y0.entropy()))
  1427. assert_allclose(Y.entropy(), Y0.entropy())
  1428. assert_allclose(Y.median(), Y0.ppf(0.5))
  1429. assert_allclose(Y.mean(), Y0.mean())
  1430. assert_allclose(Y.variance(), Y0.var())
  1431. assert_allclose(Y.standard_deviation(), np.sqrt(Y0.var()))
  1432. assert_allclose(Y.skewness(), Y0.stats('s'))
  1433. assert_allclose(Y.kurtosis(), Y0.stats('k') + 3)
  1434. assert_allclose(Y.support(), Y0.support())
  1435. assert_allclose(Y.pdf(y), Y0.pdf(y))
  1436. assert_allclose(Y.cdf(y), Y0.cdf(y))
  1437. assert_allclose(Y.ccdf(y), Y0.sf(y))
  1438. assert_allclose(Y.icdf(p), Y0.ppf(p))
  1439. assert_allclose(Y.iccdf(p), Y0.isf(p))
  1440. assert_allclose(Y.logpdf(y), Y0.logpdf(y))
  1441. assert_allclose(Y.logcdf(y), Y0.logcdf(y))
  1442. assert_allclose(Y.logccdf(y), Y0.logsf(y))
  1443. with np.errstate(invalid='ignore'):
  1444. assert_allclose(Y.ilogcdf(np.log(p)), Y0.ppf(p))
  1445. assert_allclose(Y.ilogccdf(np.log(p)), Y0.isf(p))
  1446. seed = 3984593485
  1447. assert_allclose(Y.sample(rng=seed), np.log(X.sample(rng=seed)))
  1448. def test_monotonic_transforms(self):
  1449. # Some tests of monotonic transforms that are better to be grouped or
  1450. # don't fit well above
  1451. X = Uniform(a=1, b=2)
  1452. X_str = "Uniform(a=1.0, b=2.0)"
  1453. assert str(stats.log(X)) == f"log({X_str})"
  1454. assert str(1 / X) == f"1/({X_str})"
  1455. assert str(stats.exp(X)) == f"exp({X_str})"
  1456. X = Uniform(a=-1, b=2)
  1457. message = "Division by a random variable is only implemented when the..."
  1458. with pytest.raises(NotImplementedError, match=message):
  1459. 1 / X
  1460. message = "The logarithm of a random variable is only implemented when the..."
  1461. with pytest.raises(NotImplementedError, match=message):
  1462. stats.log(X)
  1463. message = "Raising an argument to the power of a random variable is only..."
  1464. with pytest.raises(NotImplementedError, match=message):
  1465. (-2) ** X
  1466. with pytest.raises(NotImplementedError, match=message):
  1467. 1 ** X
  1468. with pytest.raises(NotImplementedError, match=message):
  1469. [0.5, 1.5] ** X
  1470. message = "Raising a random variable to the power of an argument is only"
  1471. with pytest.raises(NotImplementedError, match=message):
  1472. X ** (-2)
  1473. with pytest.raises(NotImplementedError, match=message):
  1474. X ** 0
  1475. with pytest.raises(NotImplementedError, match=message):
  1476. X ** [0.5, 1.5]
  1477. def test_arithmetic_operators(self):
  1478. rng = np.random.default_rng(2348923495832349834)
  1479. a, b, loc, scale = 0.294, 1.34, 0.57, 1.16
  1480. x = rng.uniform(-3, 3, 100)
  1481. Y = _LogUniform(a=a, b=b)
  1482. X = scale*Y + loc
  1483. assert_allclose(X.cdf(x), Y.cdf((x - loc) / scale))
  1484. X = loc + Y*scale
  1485. assert_allclose(X.cdf(x), Y.cdf((x - loc) / scale))
  1486. X = Y/scale - loc
  1487. assert_allclose(X.cdf(x), Y.cdf((x + loc) * scale))
  1488. X = loc -_LogUniform(a=a, b=b)/scale
  1489. assert_allclose(X.cdf(x), Y.ccdf((-x + loc)*scale))
  1490. def test_abs(self):
  1491. rng = np.random.default_rng(81345982345826)
  1492. loc = rng.random((3, 1))
  1493. Y = stats.abs(Normal() + loc)
  1494. Y0 = stats.foldnorm(loc)
  1495. y = Y0.rvs((3, 10), random_state=rng)
  1496. p = Y0.cdf(y)
  1497. assert_allclose(Y.logentropy(), np.log(Y0.entropy() + 0j))
  1498. assert_allclose(Y.entropy(), Y0.entropy())
  1499. assert_allclose(Y.median(), Y0.ppf(0.5))
  1500. assert_allclose(Y.mean(), Y0.mean())
  1501. assert_allclose(Y.variance(), Y0.var())
  1502. assert_allclose(Y.standard_deviation(), np.sqrt(Y0.var()))
  1503. assert_allclose(Y.skewness(), Y0.stats('s'))
  1504. assert_allclose(Y.kurtosis(), Y0.stats('k') + 3)
  1505. assert_allclose(Y.support(), Y0.support())
  1506. assert_allclose(Y.pdf(y), Y0.pdf(y))
  1507. assert_allclose(Y.cdf(y), Y0.cdf(y))
  1508. assert_allclose(Y.ccdf(y), Y0.sf(y))
  1509. assert_allclose(Y.icdf(p), Y0.ppf(p))
  1510. assert_allclose(Y.iccdf(p), Y0.isf(p))
  1511. assert_allclose(Y.logpdf(y), Y0.logpdf(y))
  1512. assert_allclose(Y.logcdf(y), Y0.logcdf(y))
  1513. assert_allclose(Y.logccdf(y), Y0.logsf(y))
  1514. assert_allclose(Y.ilogcdf(np.log(p)), Y0.ppf(p))
  1515. assert_allclose(Y.ilogccdf(np.log(p)), Y0.isf(p))
  1516. sample = Y.sample(10)
  1517. assert np.all(sample > 0)
  1518. def test_abs_finite_support(self):
  1519. # The original implementation of `FoldedDistribution` might evaluate
  1520. # the private distribution methods outside the support. Check that this
  1521. # is resolved.
  1522. Weibull = stats.make_distribution(stats.weibull_min)
  1523. X = Weibull(c=2)
  1524. Y = abs(-X)
  1525. assert_equal(X.logpdf(1), Y.logpdf(1))
  1526. assert_equal(X.pdf(1), Y.pdf(1))
  1527. assert_equal(X.logcdf(1), Y.logcdf(1))
  1528. assert_equal(X.cdf(1), Y.cdf(1))
  1529. assert_equal(X.logccdf(1), Y.logccdf(1))
  1530. assert_equal(X.ccdf(1), Y.ccdf(1))
  1531. def test_pow(self):
  1532. rng = np.random.default_rng(81345982345826)
  1533. Y = Normal()**2
  1534. Y0 = stats.chi2(df=1)
  1535. y = Y0.rvs(10, random_state=rng)
  1536. p = Y0.cdf(y)
  1537. assert_allclose(Y.logentropy(), np.log(Y0.entropy() + 0j), rtol=1e-6)
  1538. assert_allclose(Y.entropy(), Y0.entropy(), rtol=1e-6)
  1539. assert_allclose(Y.median(), Y0.median())
  1540. assert_allclose(Y.mean(), Y0.mean())
  1541. assert_allclose(Y.variance(), Y0.var())
  1542. assert_allclose(Y.standard_deviation(), np.sqrt(Y0.var()))
  1543. assert_allclose(Y.skewness(), Y0.stats('s'))
  1544. assert_allclose(Y.kurtosis(), Y0.stats('k') + 3)
  1545. assert_allclose(Y.support(), Y0.support())
  1546. assert_allclose(Y.pdf(y), Y0.pdf(y))
  1547. assert_allclose(Y.cdf(y), Y0.cdf(y))
  1548. assert_allclose(Y.ccdf(y), Y0.sf(y))
  1549. assert_allclose(Y.icdf(p), Y0.ppf(p))
  1550. assert_allclose(Y.iccdf(p), Y0.isf(p))
  1551. assert_allclose(Y.logpdf(y), Y0.logpdf(y))
  1552. assert_allclose(Y.logcdf(y), Y0.logcdf(y))
  1553. assert_allclose(Y.logccdf(y), Y0.logsf(y))
  1554. assert_allclose(Y.ilogcdf(np.log(p)), Y0.ppf(p))
  1555. assert_allclose(Y.ilogccdf(np.log(p)), Y0.isf(p))
  1556. sample = Y.sample(10)
  1557. assert np.all(sample > 0)
  1558. class TestOrderStatistic:
  1559. @pytest.mark.fail_slow(20) # Moments require integration
  1560. def test_order_statistic(self):
  1561. rng = np.random.default_rng(7546349802439582)
  1562. X = Uniform(a=0, b=1)
  1563. n = 5
  1564. r = np.asarray([[1], [3], [5]])
  1565. Y = stats.order_statistic(X, n=n, r=r)
  1566. Y0 = stats.beta(r, n + 1 - r)
  1567. y = Y0.rvs((3, 10), random_state=rng)
  1568. p = Y0.cdf(y)
  1569. # log methods need some attention before merge
  1570. assert_allclose(np.exp(Y.logentropy()), Y0.entropy())
  1571. assert_allclose(Y.entropy(), Y0.entropy())
  1572. assert_allclose(Y.mean(), Y0.mean())
  1573. assert_allclose(Y.variance(), Y0.var())
  1574. assert_allclose(Y.skewness(), Y0.stats('s'), atol=1e-15)
  1575. assert_allclose(Y.kurtosis(), Y0.stats('k') + 3, atol=1e-15)
  1576. assert_allclose(Y.median(), Y0.ppf(0.5))
  1577. assert_allclose(Y.support(), Y0.support())
  1578. assert_allclose(Y.pdf(y), Y0.pdf(y))
  1579. assert_allclose(Y.cdf(y, method='formula'), Y.cdf(y, method='quadrature'))
  1580. assert_allclose(Y.ccdf(y, method='formula'), Y.ccdf(y, method='quadrature'))
  1581. assert_allclose(Y.icdf(p, method='formula'), Y.icdf(p, method='inversion'))
  1582. assert_allclose(Y.iccdf(p, method='formula'), Y.iccdf(p, method='inversion'))
  1583. assert_allclose(Y.logpdf(y), Y0.logpdf(y))
  1584. assert_allclose(Y.logcdf(y), Y0.logcdf(y))
  1585. assert_allclose(Y.logccdf(y), Y0.logsf(y))
  1586. with np.errstate(invalid='ignore', divide='ignore'):
  1587. assert_allclose(Y.ilogcdf(np.log(p),), Y0.ppf(p))
  1588. assert_allclose(Y.ilogccdf(np.log(p)), Y0.isf(p))
  1589. message = "`r` and `n` must contain only positive integers."
  1590. with pytest.raises(ValueError, match=message):
  1591. stats.order_statistic(X, n=n, r=-1)
  1592. with pytest.raises(ValueError, match=message):
  1593. stats.order_statistic(X, n=-1, r=r)
  1594. with pytest.raises(ValueError, match=message):
  1595. stats.order_statistic(X, n=n, r=1.5)
  1596. with pytest.raises(ValueError, match=message):
  1597. stats.order_statistic(X, n=1.5, r=r)
  1598. def test_support_gh22037(self):
  1599. # During review of gh-22037, it was noted that the `support` of
  1600. # an `OrderStatisticDistribution` returned incorrect results;
  1601. # this was resolved by overriding `_support`.
  1602. Uniform = stats.make_distribution(stats.uniform)
  1603. X = Uniform()
  1604. Y = X*5 + 2
  1605. Z = stats.order_statistic(Y, r=3, n=5)
  1606. assert_allclose(Z.support(), Y.support())
  1607. def test_composition_gh22037(self):
  1608. # During review of gh-22037, it was noted that an error was
  1609. # raised when creating an `OrderStatisticDistribution` from
  1610. # a `TruncatedDistribution`. This was resolved by overriding
  1611. # `_update_parameters`.
  1612. Normal = stats.make_distribution(stats.norm)
  1613. TruncatedNormal = stats.make_distribution(stats.truncnorm)
  1614. a, b = [-2, -1], 1
  1615. r, n = 3, [[4], [5]]
  1616. x = [[[-0.3]], [[0.1]]]
  1617. X1 = Normal()
  1618. Y1 = stats.truncate(X1, a, b)
  1619. Z1 = stats.order_statistic(Y1, r=r, n=n)
  1620. X2 = TruncatedNormal(a=a, b=b)
  1621. Z2 = stats.order_statistic(X2, r=r, n=n)
  1622. np.testing.assert_allclose(Z1.cdf(x), Z2.cdf(x))
  1623. class TestFullCoverage:
  1624. # Adds tests just to get to 100% test coverage; this way it's more obvious
  1625. # if new lines are untested.
  1626. def test_Domain(self):
  1627. with pytest.raises(NotImplementedError):
  1628. _Domain.contains(None, 1.)
  1629. with pytest.raises(NotImplementedError):
  1630. _Domain.get_numerical_endpoints(None, 1.)
  1631. with pytest.raises(NotImplementedError):
  1632. _Domain.__str__(None)
  1633. def test_Parameter(self):
  1634. with pytest.raises(NotImplementedError):
  1635. _Parameter.validate(None, 1.)
  1636. @pytest.mark.parametrize(("dtype_in", "dtype_out"),
  1637. [(np.float16, np.float16),
  1638. (np.int16, np.float64)])
  1639. def test_RealParameter_uncommon_dtypes(self, dtype_in, dtype_out):
  1640. domain = _RealInterval((-1, 1))
  1641. parameter = _RealParameter('x', domain=domain)
  1642. x = np.asarray([0.5, 2.5], dtype=dtype_in)
  1643. arr, dtype, valid = parameter.validate(x, parameter_values={})
  1644. assert_equal(arr, x)
  1645. assert dtype == dtype_out
  1646. assert_equal(valid, [True, False])
  1647. def test_ContinuousDistribution_set_invalid_nan(self):
  1648. # Exercise code paths when formula returns wrong shape and dtype
  1649. # We could consider making this raise an error to force authors
  1650. # to return the right shape and dytpe, but this would need to be
  1651. # configurable.
  1652. class TestDist(ContinuousDistribution):
  1653. _variable = _RealParameter('x', domain=_RealInterval(endpoints=(0., 1.)))
  1654. def _logpdf_formula(self, x, *args, **kwargs):
  1655. return 0
  1656. X = TestDist()
  1657. dtype = np.float32
  1658. X._dtype = dtype
  1659. x = np.asarray([0.5], dtype=dtype)
  1660. assert X.logpdf(x).dtype == dtype
  1661. def test_fiinfo(self):
  1662. assert _fiinfo(np.float64(1.)).max == np.finfo(np.float64).max
  1663. assert _fiinfo(np.int64(1)).max == np.iinfo(np.int64).max
  1664. def test_generate_domain_support(self):
  1665. msg = _generate_domain_support(StandardNormal)
  1666. assert "accepts no distribution parameters" in msg
  1667. msg = _generate_domain_support(Normal)
  1668. assert "accepts one parameterization" in msg
  1669. msg = _generate_domain_support(_LogUniform)
  1670. assert "accepts two parameterizations" in msg
  1671. def test_ContinuousDistribution__repr__(self):
  1672. X = Uniform(a=0, b=1)
  1673. if np.__version__ < "2":
  1674. assert repr(X) == "Uniform(a=0.0, b=1.0)"
  1675. else:
  1676. assert repr(X) == "Uniform(a=np.float64(0.0), b=np.float64(1.0))"
  1677. if np.__version__ < "2":
  1678. assert repr(X*3 + 2) == "3.0*Uniform(a=0.0, b=1.0) + 2.0"
  1679. else:
  1680. assert repr(X*3 + 2) == (
  1681. "np.float64(3.0)*Uniform(a=np.float64(0.0), b=np.float64(1.0))"
  1682. " + np.float64(2.0)"
  1683. )
  1684. X = Uniform(a=np.zeros(4), b=1)
  1685. assert repr(X) == "Uniform(a=array([0., 0., 0., 0.]), b=1)"
  1686. X = Uniform(a=np.zeros(4, dtype=np.float32), b=np.ones(4, dtype=np.float32))
  1687. assert repr(X) == (
  1688. "Uniform(a=array([0., 0., 0., 0.], dtype=float32),"
  1689. " b=array([1., 1., 1., 1.], dtype=float32))"
  1690. )
  1691. class TestReprs:
  1692. U = Uniform(a=0, b=1)
  1693. V = Uniform(a=np.float32(0.0), b=np.float32(1.0))
  1694. X = Normal(mu=-1, sigma=1)
  1695. Y = Normal(mu=1, sigma=1)
  1696. Z = Normal(mu=np.zeros(1000), sigma=1)
  1697. @pytest.mark.parametrize(
  1698. "dist",
  1699. [
  1700. U,
  1701. U - np.array([1.0, 2.0]),
  1702. pytest.param(
  1703. V,
  1704. marks=pytest.mark.skipif(
  1705. np.__version__ < "2",
  1706. reason="numpy 1.x didn't have dtype in repr",
  1707. )
  1708. ),
  1709. pytest.param(
  1710. np.ones(2, dtype=np.float32)*V + np.zeros(2, dtype=np.float64),
  1711. marks=pytest.mark.skipif(
  1712. np.__version__ < "2",
  1713. reason="numpy 1.x didn't have dtype in repr",
  1714. )
  1715. ),
  1716. 3*U + 2,
  1717. U**4,
  1718. (3*U + 2)**4,
  1719. (3*U + 2)**3,
  1720. 2**U,
  1721. 2**(3*U + 1),
  1722. 1 / (1 + U),
  1723. stats.order_statistic(U, r=3, n=5),
  1724. stats.truncate(U, 0.2, 0.8),
  1725. stats.Mixture([X, Y], weights=[0.3, 0.7]),
  1726. abs(U),
  1727. stats.exp(U),
  1728. stats.log(1 + U),
  1729. np.array([1.0, 2.0])*U + np.array([2.0, 3.0]),
  1730. ]
  1731. )
  1732. def test_executable(self, dist):
  1733. # Test that reprs actually evaluate to proper distribution
  1734. # provided relevant imports are made.
  1735. from numpy import array # noqa: F401
  1736. from numpy import float32 # noqa: F401
  1737. from scipy.stats import abs, exp, log, order_statistic, truncate # noqa: F401
  1738. from scipy.stats import Mixture, Normal # noqa: F401
  1739. from scipy.stats._new_distributions import Uniform # noqa: F401
  1740. new_dist = eval(repr(dist))
  1741. # A basic check that the distributions are the same
  1742. sample1 = dist.sample(shape=10, rng=1234)
  1743. sample2 = new_dist.sample(shape=10, rng=1234)
  1744. assert_equal(sample1, sample2)
  1745. assert sample1.dtype is sample2.dtype
  1746. @pytest.mark.parametrize(
  1747. "dist",
  1748. [
  1749. Z,
  1750. np.full(1000, 2.0) * X + 1.0,
  1751. 2.0 * X + np.full(1000, 1.0),
  1752. np.full(1000, 2.0) * X + 1.0,
  1753. stats.truncate(Z, -1, 1),
  1754. stats.truncate(Z, -np.ones(1000), np.ones(1000)),
  1755. stats.order_statistic(X, r=np.arange(1, 1000), n=1000),
  1756. Z**2,
  1757. 1.0 / (1 + stats.exp(Z)),
  1758. 2**Z,
  1759. ]
  1760. )
  1761. def test_not_too_long(self, dist):
  1762. # Tests that array summarization is working to ensure reprs aren't too long.
  1763. # None of the reprs above will be executable.
  1764. assert len(repr(dist)) < 250
  1765. class MixedDist(ContinuousDistribution):
  1766. _variable = _RealParameter('x', domain=_RealInterval(endpoints=(-np.inf, np.inf)))
  1767. def _pdf_formula(self, x, *args, **kwargs):
  1768. return (0.4 * 1/(1.1 * np.sqrt(2*np.pi)) * np.exp(-0.5*((x+0.25)/1.1)**2)
  1769. + 0.6 * 1/(0.9 * np.sqrt(2*np.pi)) * np.exp(-0.5*((x-0.5)/0.9)**2))
  1770. class TestMixture:
  1771. def test_input_validation(self):
  1772. message = "`components` must contain at least one random variable."
  1773. with pytest.raises(ValueError, match=message):
  1774. Mixture([])
  1775. message = "Each element of `components` must be an instance..."
  1776. with pytest.raises(ValueError, match=message):
  1777. Mixture((1, 2, 3))
  1778. message = "All elements of `components` must have scalar shapes."
  1779. with pytest.raises(ValueError, match=message):
  1780. Mixture([Normal(mu=[1, 2]), Normal()])
  1781. message = "`components` and `weights` must have the same length."
  1782. with pytest.raises(ValueError, match=message):
  1783. Mixture([Normal()], weights=[0.5, 0.5])
  1784. message = "`weights` must have floating point dtype."
  1785. with pytest.raises(ValueError, match=message):
  1786. Mixture([Normal()], weights=[1])
  1787. message = "`weights` must have floating point dtype."
  1788. with pytest.raises(ValueError, match=message):
  1789. Mixture([Normal()], weights=[1])
  1790. message = "`weights` must sum to 1.0."
  1791. with pytest.raises(ValueError, match=message):
  1792. Mixture([Normal(), Normal()], weights=[0.5, 1.0])
  1793. message = "All `weights` must be non-negative."
  1794. with pytest.raises(ValueError, match=message):
  1795. Mixture([Normal(), Normal()], weights=[1.5, -0.5])
  1796. @pytest.mark.parametrize('shape', [(), (10,)])
  1797. def test_basic(self, shape):
  1798. rng = np.random.default_rng(582348972387243524)
  1799. X = Mixture((Normal(mu=-0.25, sigma=1.1), Normal(mu=0.5, sigma=0.9)),
  1800. weights=(0.4, 0.6))
  1801. Y = MixedDist()
  1802. x = rng.random(shape)
  1803. def assert_allclose(res, ref, **kwargs):
  1804. if shape == ():
  1805. assert np.isscalar(res)
  1806. np.testing.assert_allclose(res, ref, **kwargs)
  1807. assert_allclose(X.logentropy(), Y.logentropy())
  1808. assert_allclose(X.entropy(), Y.entropy())
  1809. assert_allclose(X.mode(), Y.mode())
  1810. assert_allclose(X.median(), Y.median())
  1811. assert_allclose(X.mean(), Y.mean())
  1812. assert_allclose(X.variance(), Y.variance())
  1813. assert_allclose(X.standard_deviation(), Y.standard_deviation())
  1814. assert_allclose(X.skewness(), Y.skewness())
  1815. assert_allclose(X.kurtosis(), Y.kurtosis())
  1816. assert_allclose(X.logpdf(x), Y.logpdf(x))
  1817. assert_allclose(X.pdf(x), Y.pdf(x))
  1818. assert_allclose(X.logcdf(x), Y.logcdf(x))
  1819. assert_allclose(X.cdf(x), Y.cdf(x))
  1820. assert_allclose(X.logccdf(x), Y.logccdf(x))
  1821. assert_allclose(X.ccdf(x), Y.ccdf(x))
  1822. assert_allclose(X.ilogcdf(x), Y.ilogcdf(x))
  1823. assert_allclose(X.icdf(x), Y.icdf(x))
  1824. assert_allclose(X.ilogccdf(x), Y.ilogccdf(x))
  1825. assert_allclose(X.iccdf(x), Y.iccdf(x))
  1826. for kind in ['raw', 'central', 'standardized']:
  1827. for order in range(5):
  1828. assert_allclose(X.moment(order, kind=kind),
  1829. Y.moment(order, kind=kind),
  1830. atol=1e-15)
  1831. # weak test of `sample`
  1832. shape = (10, 20, 5)
  1833. y = X.sample(shape, rng=rng)
  1834. assert y.shape == shape
  1835. assert stats.ks_1samp(y.ravel(), X.cdf).pvalue > 0.05
  1836. def test_default_weights(self):
  1837. a = 1.1
  1838. Gamma = stats.make_distribution(stats.gamma)
  1839. X = Gamma(a=a)
  1840. Y = stats.Mixture((X, -X))
  1841. x = np.linspace(-4, 4, 300)
  1842. assert_allclose(Y.pdf(x), stats.dgamma(a=a).pdf(x))
  1843. def test_properties(self):
  1844. components = [Normal(mu=-0.25, sigma=1.1), Normal(mu=0.5, sigma=0.9)]
  1845. weights = (0.4, 0.6)
  1846. X = Mixture(components, weights=weights)
  1847. # Replacing properties doesn't work
  1848. # Different version of Python have different messages
  1849. with pytest.raises(AttributeError):
  1850. X.components = 10
  1851. with pytest.raises(AttributeError):
  1852. X.weights = 10
  1853. # Mutation doesn't work
  1854. X.components[0] = components[1]
  1855. assert X.components[0] == components[0]
  1856. X.weights[0] = weights[1]
  1857. assert X.weights[0] == weights[0]
  1858. def test_inverse(self):
  1859. # Originally, inverse relied on the mean to start the bracket search.
  1860. # This didn't work for distributions with non-finite mean. Check that
  1861. # this is resolved.
  1862. rng = np.random.default_rng(24358934657854237863456)
  1863. Cauchy = stats.make_distribution(stats.cauchy)
  1864. X0 = Cauchy()
  1865. X = stats.Mixture([X0, X0])
  1866. p = rng.random(size=10)
  1867. np.testing.assert_allclose(X.icdf(p), X0.icdf(p))
  1868. np.testing.assert_allclose(X.iccdf(p), X0.iccdf(p))
  1869. np.testing.assert_allclose(X.ilogcdf(p), X0.ilogcdf(p))
  1870. np.testing.assert_allclose(X.ilogccdf(p), X0.ilogccdf(p))
  1871. def test_zipfian_distribution_wrapper():
  1872. # Regression test for gh-23678: calling the cdf method at the end
  1873. # point of the Zipfian distribution would generate a warning.
  1874. Zipfian = stats.make_distribution(stats.zipfian)
  1875. zdist = Zipfian(a=0.75, n=15)
  1876. # This should not generate any warnings.
  1877. assert_equal(zdist.cdf(15), 1.0)