common_tests.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. import pickle
  2. import numpy as np
  3. import numpy.testing as npt
  4. from numpy.testing import assert_allclose, assert_equal
  5. from pytest import raises as assert_raises
  6. import numpy.ma.testutils as ma_npt
  7. from scipy._lib._util import (
  8. getfullargspec_no_self as _getfullargspec, np_long
  9. )
  10. from scipy._lib._array_api_no_0d import xp_assert_equal
  11. from scipy import stats
  12. def check_named_results(res, attributes, ma=False, xp=None):
  13. for i, attr in enumerate(attributes):
  14. if ma:
  15. ma_npt.assert_equal(res[i], getattr(res, attr))
  16. elif xp is not None:
  17. xp_assert_equal(res[i], getattr(res, attr))
  18. else:
  19. npt.assert_equal(res[i], getattr(res, attr))
  20. def check_normalization(distfn, args, distname):
  21. norm_moment = distfn.moment(0, *args)
  22. npt.assert_allclose(norm_moment, 1.0)
  23. if distname == "rv_histogram_instance":
  24. atol, rtol = 1e-5, 0
  25. else:
  26. atol, rtol = 1e-7, 1e-7
  27. normalization_expect = distfn.expect(lambda x: 1, args=args)
  28. npt.assert_allclose(normalization_expect, 1.0, atol=atol, rtol=rtol,
  29. err_msg=distname, verbose=True)
  30. _a, _b = distfn.support(*args)
  31. normalization_cdf = distfn.cdf(_b, *args)
  32. npt.assert_allclose(normalization_cdf, 1.0)
  33. def check_moment(distfn, arg, m, v, msg):
  34. m1 = distfn.moment(1, *arg)
  35. m2 = distfn.moment(2, *arg)
  36. if not np.isinf(m):
  37. npt.assert_almost_equal(m1, m, decimal=10,
  38. err_msg=msg + ' - 1st moment')
  39. else: # or np.isnan(m1),
  40. npt.assert_(np.isinf(m1),
  41. msg + f' - 1st moment -infinite, m1={str(m1)}')
  42. if not np.isinf(v):
  43. npt.assert_almost_equal(m2 - m1 * m1, v, decimal=10,
  44. err_msg=msg + ' - 2ndt moment')
  45. else: # or np.isnan(m2),
  46. npt.assert_(np.isinf(m2), msg + f' - 2nd moment -infinite, {m2=}')
  47. def check_mean_expect(distfn, arg, m, msg):
  48. if np.isfinite(m):
  49. m1 = distfn.expect(lambda x: x, arg)
  50. npt.assert_almost_equal(m1, m, decimal=5,
  51. err_msg=msg + ' - 1st moment (expect)')
  52. def check_var_expect(distfn, arg, m, v, msg):
  53. dist_looser_tolerances = {"rv_histogram_instance" , "ksone"}
  54. kwargs = {'rtol': 5e-6} if msg in dist_looser_tolerances else {}
  55. if np.isfinite(v):
  56. m2 = distfn.expect(lambda x: x*x, arg)
  57. npt.assert_allclose(m2, v + m*m, **kwargs)
  58. def check_skew_expect(distfn, arg, m, v, s, msg):
  59. if np.isfinite(s):
  60. m3e = distfn.expect(lambda x: np.power(x-m, 3), arg)
  61. npt.assert_almost_equal(m3e, s * np.power(v, 1.5),
  62. decimal=5, err_msg=msg + ' - skew')
  63. else:
  64. npt.assert_(np.isnan(s))
  65. def check_kurt_expect(distfn, arg, m, v, k, msg):
  66. if np.isfinite(k):
  67. m4e = distfn.expect(lambda x: np.power(x-m, 4), arg)
  68. npt.assert_allclose(m4e, (k + 3.) * np.power(v, 2),
  69. atol=1e-5, rtol=1e-5,
  70. err_msg=msg + ' - kurtosis')
  71. elif not np.isposinf(k):
  72. npt.assert_(np.isnan(k))
  73. def check_munp_expect(dist, args, msg):
  74. # If _munp is overridden, test a higher moment. (Before gh-18634, some
  75. # distributions had issues with moments 5 and higher.)
  76. if dist._munp.__func__ != stats.rv_continuous._munp:
  77. res = dist.moment(5, *args) # shouldn't raise an error
  78. ref = dist.expect(lambda x: x ** 5, args, lb=-np.inf, ub=np.inf)
  79. if not np.isfinite(res): # could be valid; automated test can't know
  80. return
  81. # loose tolerance, mostly to see whether _munp returns *something*
  82. assert_allclose(res, ref, atol=1e-10, rtol=1e-4,
  83. err_msg=msg + ' - higher moment / _munp')
  84. def check_entropy(distfn, arg, msg):
  85. ent = distfn.entropy(*arg)
  86. npt.assert_(not np.isnan(ent), msg + 'test Entropy is nan')
  87. def check_private_entropy(distfn, args, superclass):
  88. # compare a generic _entropy with the distribution-specific implementation
  89. npt.assert_allclose(distfn._entropy(*args),
  90. superclass._entropy(distfn, *args))
  91. def check_entropy_vect_scale(distfn, arg):
  92. # check 2-d
  93. sc = np.asarray([[1, 2], [3, 4]])
  94. v_ent = distfn.entropy(*arg, scale=sc)
  95. s_ent = [distfn.entropy(*arg, scale=s) for s in sc.ravel()]
  96. s_ent = np.asarray(s_ent).reshape(v_ent.shape)
  97. assert_allclose(v_ent, s_ent, atol=1e-14)
  98. # check invalid value, check cast
  99. sc = [1, 2, -3]
  100. v_ent = distfn.entropy(*arg, scale=sc)
  101. s_ent = [distfn.entropy(*arg, scale=s) for s in sc]
  102. s_ent = np.asarray(s_ent).reshape(v_ent.shape)
  103. assert_allclose(v_ent, s_ent, atol=1e-14)
  104. def check_edge_support(distfn, args):
  105. # Make sure that x=self.a and self.b are handled correctly.
  106. x = distfn.support(*args)
  107. if isinstance(distfn, stats.rv_discrete):
  108. x = x[0]-1, x[1]
  109. npt.assert_equal(distfn.cdf(x, *args), [0.0, 1.0])
  110. npt.assert_equal(distfn.sf(x, *args), [1.0, 0.0])
  111. if distfn.name not in ('skellam', 'dlaplace'):
  112. # with a = -inf, log(0) generates warnings
  113. npt.assert_equal(distfn.logcdf(x, *args), [-np.inf, 0.0])
  114. npt.assert_equal(distfn.logsf(x, *args), [0.0, -np.inf])
  115. npt.assert_equal(distfn.ppf([0.0, 1.0], *args), x)
  116. npt.assert_equal(distfn.isf([0.0, 1.0], *args), x[::-1])
  117. # out-of-bounds for isf & ppf
  118. npt.assert_(np.isnan(distfn.isf([-1, 2], *args)).all())
  119. npt.assert_(np.isnan(distfn.ppf([-1, 2], *args)).all())
  120. def check_named_args(distfn, x, shape_args, defaults, meths):
  121. ## Check calling w/ named arguments.
  122. # check consistency of shapes, numargs and _parse signature
  123. signature = _getfullargspec(distfn._parse_args)
  124. npt.assert_(signature.varargs is None)
  125. npt.assert_(signature.varkw is None)
  126. npt.assert_(not signature.kwonlyargs)
  127. npt.assert_(list(signature.defaults) == list(defaults))
  128. shape_argnames = signature.args[:-len(defaults)] # a, b, loc=0, scale=1
  129. if distfn.shapes:
  130. shapes_ = distfn.shapes.replace(',', ' ').split()
  131. else:
  132. shapes_ = ''
  133. npt.assert_(len(shapes_) == distfn.numargs)
  134. npt.assert_(len(shapes_) == len(shape_argnames))
  135. # check calling w/ named arguments
  136. shape_args = list(shape_args)
  137. vals = [meth(x, *shape_args) for meth in meths]
  138. npt.assert_(np.all(np.isfinite(vals)))
  139. names, a, k = shape_argnames[:], shape_args[:], {}
  140. while names:
  141. k.update({names.pop(): a.pop()})
  142. v = [meth(x, *a, **k) for meth in meths]
  143. npt.assert_array_equal(vals, v)
  144. if 'n' not in k.keys():
  145. # `n` is first parameter of moment(), so can't be used as named arg
  146. npt.assert_equal(distfn.moment(1, *a, **k),
  147. distfn.moment(1, *shape_args))
  148. # unknown arguments should not go through:
  149. k.update({'kaboom': 42})
  150. assert_raises(TypeError, distfn.cdf, x, **k)
  151. def check_random_state_property(distfn, args):
  152. # check the random_state attribute of a distribution *instance*
  153. # baseline: this relies on the global state
  154. np.random.seed(1234) # valid use of np.random.seed
  155. distfn.random_state = None
  156. r0 = distfn.rvs(*args, size=8)
  157. # use an explicit instance-level random_state
  158. distfn.random_state = 1234
  159. r1 = distfn.rvs(*args, size=8)
  160. npt.assert_equal(r0, r1)
  161. distfn.random_state = np.random.RandomState(1234)
  162. r2 = distfn.rvs(*args, size=8)
  163. npt.assert_equal(r0, r2)
  164. # check that np.random.Generator can be used (numpy >= 1.17)
  165. if hasattr(np.random, 'default_rng'):
  166. # obtain a np.random.Generator object
  167. rng = np.random.default_rng(1234)
  168. distfn.rvs(*args, size=1, random_state=rng)
  169. # can override the instance-level random_state for an individual .rvs call
  170. distfn.random_state = 2
  171. orig_state = distfn.random_state.get_state()
  172. r3 = distfn.rvs(*args, size=8, random_state=np.random.RandomState(1234))
  173. npt.assert_equal(r0, r3)
  174. # ... and that does not alter the instance-level random_state!
  175. npt.assert_equal(distfn.random_state.get_state(), orig_state)
  176. def check_meth_dtype(distfn, arg, meths):
  177. q0 = [0.25, 0.5, 0.75]
  178. x0 = distfn.ppf(q0, *arg)
  179. x_cast = [x0.astype(tp) for tp in (np_long, np.float16, np.float32,
  180. np.float64)]
  181. for x in x_cast:
  182. # casting may have clipped the values, exclude those
  183. distfn._argcheck(*arg)
  184. x = x[(distfn.a < x) & (x < distfn.b)]
  185. for meth in meths:
  186. val = meth(x, *arg)
  187. npt.assert_(val.dtype == np.float64)
  188. def check_ppf_dtype(distfn, arg):
  189. q0 = np.asarray([0.25, 0.5, 0.75])
  190. q_cast = [q0.astype(tp) for tp in (np.float16, np.float32, np.float64)]
  191. for q in q_cast:
  192. for meth in [distfn.ppf, distfn.isf]:
  193. val = meth(q, *arg)
  194. npt.assert_(val.dtype == np.float64)
  195. def check_cmplx_deriv(distfn, arg):
  196. # Distributions allow complex arguments.
  197. def deriv(f, x, *arg):
  198. x = np.asarray(x)
  199. h = 1e-10
  200. return (f(x + h*1j, *arg)/h).imag
  201. x0 = distfn.ppf([0.25, 0.51, 0.75], *arg)
  202. x_cast = [x0.astype(tp) for tp in (np_long, np.float16, np.float32,
  203. np.float64)]
  204. for x in x_cast:
  205. # casting may have clipped the values, exclude those
  206. distfn._argcheck(*arg)
  207. x = x[(distfn.a < x) & (x < distfn.b)]
  208. pdf, cdf, sf = distfn.pdf(x, *arg), distfn.cdf(x, *arg), distfn.sf(x, *arg)
  209. assert_allclose(deriv(distfn.cdf, x, *arg), pdf, rtol=1e-5)
  210. assert_allclose(deriv(distfn.logcdf, x, *arg), pdf/cdf, rtol=1e-5)
  211. assert_allclose(deriv(distfn.sf, x, *arg), -pdf, rtol=1e-5)
  212. assert_allclose(deriv(distfn.logsf, x, *arg), -pdf/sf, rtol=1e-5)
  213. assert_allclose(deriv(distfn.logpdf, x, *arg),
  214. deriv(distfn.pdf, x, *arg) / distfn.pdf(x, *arg),
  215. rtol=1e-5)
  216. def check_pickling(distfn, args):
  217. # check that a distribution instance pickles and unpickles
  218. # pay special attention to the random_state property
  219. # save the random_state (restore later)
  220. rndm = distfn.random_state
  221. # check unfrozen
  222. distfn.random_state = 1234
  223. distfn.rvs(*args, size=8)
  224. s = pickle.dumps(distfn)
  225. r0 = distfn.rvs(*args, size=8)
  226. unpickled = pickle.loads(s)
  227. r1 = unpickled.rvs(*args, size=8)
  228. npt.assert_equal(r0, r1)
  229. # also smoke test some methods
  230. medians = [distfn.ppf(0.5, *args), unpickled.ppf(0.5, *args)]
  231. npt.assert_equal(medians[0], medians[1])
  232. npt.assert_equal(distfn.cdf(medians[0], *args),
  233. unpickled.cdf(medians[1], *args))
  234. # check frozen pickling/unpickling with rvs
  235. frozen_dist = distfn(*args)
  236. pkl = pickle.dumps(frozen_dist)
  237. unpickled = pickle.loads(pkl)
  238. r0 = frozen_dist.rvs(size=8)
  239. r1 = unpickled.rvs(size=8)
  240. npt.assert_equal(r0, r1)
  241. # check pickling/unpickling of .fit method
  242. if hasattr(distfn, "fit"):
  243. fit_function = distfn.fit
  244. pickled_fit_function = pickle.dumps(fit_function)
  245. unpickled_fit_function = pickle.loads(pickled_fit_function)
  246. assert fit_function.__name__ == unpickled_fit_function.__name__ == "fit"
  247. # restore the random_state
  248. distfn.random_state = rndm
  249. def check_freezing(distfn, args):
  250. # regression test for gh-11089: freezing a distribution fails
  251. # if loc and/or scale are specified
  252. if isinstance(distfn, stats.rv_continuous):
  253. locscale = {'loc': 1, 'scale': 2}
  254. else:
  255. locscale = {'loc': 1}
  256. rv = distfn(*args, **locscale)
  257. assert rv.a == distfn(*args).a
  258. assert rv.b == distfn(*args).b
  259. def check_rvs_broadcast(distfunc, distname, allargs, shape, shape_only, otype):
  260. rng = np.random.RandomState(123)
  261. sample = distfunc.rvs(*allargs, random_state=rng)
  262. assert_equal(sample.shape, shape, f"{distname}: rvs failed to broadcast")
  263. if not shape_only:
  264. rvs = np.vectorize(
  265. lambda *allargs: distfunc.rvs(*allargs, random_state=rng),
  266. otypes=otype)
  267. rng = np.random.RandomState(123)
  268. expected = rvs(*allargs)
  269. assert_allclose(sample, expected, rtol=1e-13)