test_marray.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. import numpy as np
  2. import pytest
  3. from scipy import stats
  4. from packaging import version
  5. from scipy._lib._array_api import xp_assert_close, xp_assert_equal, _length_nonmasked
  6. from scipy._lib._array_api import make_xp_pytest_param, make_xp_test_case
  7. from scipy._lib._array_api import SCIPY_ARRAY_API
  8. from scipy.stats._stats_py import _xp_mean, _xp_var
  9. from scipy.stats._axis_nan_policy import _axis_nan_policy_factory
  10. marray = pytest.importorskip('marray')
  11. pytestmark = [
  12. pytest.mark.skipif(
  13. not SCIPY_ARRAY_API,
  14. reason=(
  15. "special function dispatch to marray required for these tests"
  16. " is hidden behind SCIPY_ARRAY_API flag."
  17. ),
  18. ),
  19. ]
  20. skip_backend = pytest.mark.skip_xp_backends
  21. def get_arrays(n_arrays, *, dtype='float64', xp=np, shape=(7, 8), seed=84912165484321):
  22. mxp = marray._get_namespace(xp)
  23. rng = np.random.default_rng(seed)
  24. datas, masks = [], []
  25. for i in range(n_arrays):
  26. data = rng.random(size=shape)
  27. if dtype.startswith('complex'):
  28. data = 10*data * 10j*rng.standard_normal(size=shape)
  29. data = data.astype(dtype)
  30. datas.append(data)
  31. mask = rng.random(size=shape) > 0.75
  32. masks.append(mask)
  33. marrays = []
  34. nan_arrays = []
  35. for array, mask in zip(datas, masks):
  36. marrays.append(mxp.asarray(array, mask=mask))
  37. nan_array = array.copy()
  38. nan_array[mask] = xp.nan
  39. nan_arrays.append(nan_array)
  40. return mxp, marrays, nan_arrays
  41. @skip_backend('dask.array', reason='Arrays need `device` attribute: dask/dask#11711')
  42. @skip_backend('jax.numpy', reason="JAX doesn't allow item assignment.")
  43. @skip_backend('torch', reason="marray#99")
  44. @pytest.mark.parametrize('fun, kwargs', [make_xp_pytest_param(stats.gmean, {}),
  45. make_xp_pytest_param(stats.hmean, {}),
  46. make_xp_pytest_param(stats.pmean, {'p': 2})])
  47. @pytest.mark.parametrize('axis', [0, 1])
  48. def test_xmean(fun, kwargs, axis, xp):
  49. mxp, marrays, narrays = get_arrays(2, xp=xp)
  50. res = fun(marrays[0], weights=marrays[1], axis=axis, **kwargs)
  51. ref = fun(narrays[0], weights=narrays[1], nan_policy='omit', axis=axis, **kwargs)
  52. xp_assert_close(res.data, xp.asarray(ref))
  53. @skip_backend('dask.array', reason='Arrays need `device` attribute: dask/dask#11711')
  54. @skip_backend('jax.numpy', reason="JAX doesn't allow item assignment.")
  55. @skip_backend('torch', reason="marray#99")
  56. @pytest.mark.parametrize('axis', [0, 1, None])
  57. @pytest.mark.parametrize('keepdims', [False, True])
  58. def test_xp_mean(axis, keepdims, xp):
  59. mxp, marrays, narrays = get_arrays(2, xp=xp)
  60. kwargs = dict(axis=axis, keepdims=keepdims)
  61. res = _xp_mean(marrays[0], weights=marrays[1], **kwargs)
  62. ref = _xp_mean(narrays[0], weights=narrays[1], nan_policy='omit', **kwargs)
  63. xp_assert_close(res.data, xp.asarray(ref))
  64. @skip_backend('dask.array', reason='Arrays need `device` attribute: dask/dask#11711')
  65. @skip_backend('jax.numpy', reason="JAX doesn't allow item assignment.")
  66. @skip_backend('torch', reason="array-api-compat#242")
  67. @pytest.mark.parametrize('fun, kwargs',
  68. [make_xp_pytest_param(stats.moment, {'order': 2}),
  69. make_xp_pytest_param(stats.skew, {}),
  70. make_xp_pytest_param(stats.skew, {'bias': False}),
  71. make_xp_pytest_param(stats.kurtosis, {}),
  72. make_xp_pytest_param(stats.kurtosis, {'bias': False}),
  73. make_xp_pytest_param(stats.sem, {}),
  74. make_xp_pytest_param(stats.kstat, {'n': 1}),
  75. make_xp_pytest_param(stats.kstat, {'n': 2}),
  76. make_xp_pytest_param(stats.kstat, {'n': 3}),
  77. make_xp_pytest_param(stats.kstat, {'n': 4}),
  78. make_xp_pytest_param(stats.kstatvar, {'n': 1}),
  79. make_xp_pytest_param(stats.kstatvar, {'n': 2}),
  80. make_xp_pytest_param(stats.circmean, {}),
  81. make_xp_pytest_param(stats.circvar, {}),
  82. make_xp_pytest_param(stats.circstd, {}),
  83. make_xp_pytest_param(stats.gstd, {}),
  84. make_xp_pytest_param(stats.variation, {}),
  85. (_xp_var, {}),
  86. make_xp_pytest_param(stats.tmean, {'limits': (0.1, 0.9)}),
  87. make_xp_pytest_param(stats.tvar, {'limits': (0.1, 0.9)}),
  88. make_xp_pytest_param(stats.tmin, {'lowerlimit': 0.5}),
  89. make_xp_pytest_param(stats.tmax, {'upperlimit': 0.5}),
  90. make_xp_pytest_param(stats.tstd, {'limits': (0.1, 0.9)}),
  91. make_xp_pytest_param(stats.tsem, {'limits': (0.1, 0.9)}),
  92. ])
  93. @pytest.mark.parametrize('axis', [0, 1, None])
  94. def test_several(fun, kwargs, axis, xp):
  95. mxp, marrays, narrays = get_arrays(1, xp=xp)
  96. kwargs = dict(axis=axis) | kwargs
  97. res = fun(marrays[0], **kwargs)
  98. ref = fun(narrays[0], nan_policy='omit', **kwargs)
  99. xp_assert_close(res.data, xp.asarray(ref))
  100. @make_xp_test_case(stats.describe)
  101. @skip_backend('dask.array', reason='Arrays need `device` attribute: dask/dask#11711')
  102. @skip_backend('jax.numpy', reason="JAX doesn't allow item assignment.")
  103. @skip_backend('torch', reason="array-api-compat#242")
  104. @pytest.mark.parametrize('axis', [0, 1])
  105. @pytest.mark.parametrize('kwargs', [{}])
  106. def test_describe(axis, kwargs, xp):
  107. mxp, marrays, narrays = get_arrays(1, xp=xp)
  108. kwargs = dict(axis=axis) | kwargs
  109. res = stats.describe(marrays[0], **kwargs)
  110. ref = stats.describe(narrays[0], nan_policy='omit', **kwargs)
  111. xp_assert_close(res.nobs.data, xp.asarray(ref.nobs))
  112. xp_assert_close(res.minmax[0].data, xp.asarray(ref.minmax[0].data))
  113. xp_assert_close(res.minmax[1].data, xp.asarray(ref.minmax[1].data))
  114. xp_assert_close(res.variance.data, xp.asarray(ref.variance.data))
  115. xp_assert_close(res.skewness.data, xp.asarray(ref.skewness.data))
  116. xp_assert_close(res.kurtosis.data, xp.asarray(ref.kurtosis.data))
  117. @skip_backend('dask.array', reason='Arrays need `device` attribute: dask/dask#11711')
  118. @skip_backend('jax.numpy', reason="JAX doesn't allow item assignment.")
  119. @skip_backend('torch', reason="array-api-compat#242")
  120. @pytest.mark.parametrize('fun', [make_xp_pytest_param(stats.zscore),
  121. make_xp_pytest_param(stats.gzscore),
  122. make_xp_pytest_param(stats.zmap)])
  123. @pytest.mark.parametrize('axis', [0, 1, None])
  124. def test_zscore(fun, axis, xp):
  125. mxp, marrays, narrays = (get_arrays(2, xp=xp) if fun == stats.zmap
  126. else get_arrays(1, xp=xp))
  127. res = fun(*marrays, axis=axis)
  128. ref = xp.asarray(fun(*narrays, nan_policy='omit', axis=axis))
  129. xp_assert_close(res.data[~res.mask], ref[~xp.isnan(ref)])
  130. xp_assert_equal(res.mask, marrays[0].mask)
  131. @skip_backend('dask.array', reason='Arrays need `device` attribute: dask/dask#11711')
  132. @skip_backend('jax.numpy', reason="JAX doesn't allow item assignment.")
  133. @skip_backend('torch', reason="array-api-compat#242")
  134. @skip_backend('cupy', reason="special functions won't work")
  135. @pytest.mark.parametrize('f', [make_xp_pytest_param(stats.ttest_1samp),
  136. make_xp_pytest_param(stats.ttest_rel),
  137. make_xp_pytest_param(stats.ttest_ind)])
  138. @pytest.mark.parametrize('axis', [0, 1, None])
  139. def test_ttest(f, axis, xp):
  140. f_name = f.__name__
  141. mxp, marrays, narrays = get_arrays(2, xp=xp)
  142. if f_name == 'ttest_1samp':
  143. marrays[1] = mxp.mean(marrays[1], axis=axis, keepdims=axis is not None)
  144. narrays[1] = np.nanmean(narrays[1], axis=axis, keepdims=axis is not None)
  145. res = f(*marrays, axis=axis)
  146. ref = f(*narrays, nan_policy='omit', axis=axis)
  147. xp_assert_close(res.statistic.data, xp.asarray(ref.statistic))
  148. xp_assert_close(res.pvalue.data, xp.asarray(ref.pvalue))
  149. res_ci = res.confidence_interval()
  150. ref_ci = ref.confidence_interval()
  151. xp_assert_close(res_ci.low.data, xp.asarray(ref_ci.low))
  152. xp_assert_close(res_ci.high.data, xp.asarray(ref_ci.high))
  153. @skip_backend('dask.array', reason='Arrays need `device` attribute: dask/dask#11711')
  154. @skip_backend('jax.numpy', reason="JAX doesn't allow item assignment.")
  155. @skip_backend('torch', reason="array-api-compat#242")
  156. @skip_backend('cupy', reason="special functions won't work")
  157. @pytest.mark.filterwarnings("ignore::scipy.stats._axis_nan_policy.SmallSampleWarning")
  158. @pytest.mark.parametrize('f', [make_xp_pytest_param(stats.skewtest),
  159. make_xp_pytest_param(stats.kurtosistest),
  160. make_xp_pytest_param(stats.normaltest),
  161. make_xp_pytest_param(stats.jarque_bera)])
  162. @pytest.mark.parametrize('axis', [0, 1, None])
  163. def test_normality_tests(f, axis, xp):
  164. mxp, marrays, narrays = get_arrays(1, xp=xp, shape=(10, 11))
  165. res = f(*marrays, axis=axis)
  166. ref = f(*narrays, nan_policy='omit', axis=axis)
  167. xp_assert_close(res.statistic.data, xp.asarray(ref.statistic))
  168. xp_assert_close(res.pvalue.data, xp.asarray(ref.pvalue))
  169. def pd_nsamples(kwargs):
  170. return 2 if kwargs.get('f_exp', None) is not None else 1
  171. @_axis_nan_policy_factory(lambda *args: tuple(args), paired=True, n_samples=pd_nsamples)
  172. def power_divergence_ref(f_obs, f_exp=None, *, ddof, lambda_, axis=0):
  173. return stats.power_divergence(f_obs, f_exp, axis=axis, ddof=ddof, lambda_=lambda_)
  174. @make_xp_test_case(stats.chisquare, stats.power_divergence)
  175. @skip_backend('dask.array', reason='Arrays need `device` attribute: dask/dask#11711')
  176. @skip_backend('jax.numpy', reason="JAX doesn't allow item assignment.")
  177. @skip_backend('torch', reason="array-api-compat#242")
  178. @skip_backend('cupy', reason="special functions won't work")
  179. @pytest.mark.parametrize('lambda_', ['pearson', 'log-likelihood', 'freeman-tukey',
  180. 'mod-log-likelihood', 'neyman', 'cressie-read',
  181. 'chisquare'])
  182. @pytest.mark.parametrize('ddof', [0, 1])
  183. @pytest.mark.parametrize('axis', [0, 1, None])
  184. def test_power_divergence_chisquare(lambda_, ddof, axis, xp):
  185. mxp, marrays, narrays = get_arrays(2, xp=xp, shape=(5, 6))
  186. kwargs = dict(axis=axis, ddof=ddof)
  187. if lambda_ == 'chisquare':
  188. lambda_ = "pearson"
  189. def f(*args, **kwargs):
  190. return stats.chisquare(*args, **kwargs)
  191. else:
  192. def f(*args, **kwargs):
  193. return stats.power_divergence(*args, lambda_=lambda_, **kwargs)
  194. # test 1-arg
  195. res = f(marrays[0], **kwargs)
  196. ref = power_divergence_ref(narrays[0], nan_policy='omit', lambda_=lambda_, **kwargs)
  197. xp_assert_close(res.statistic.data, xp.asarray(ref[0]))
  198. xp_assert_close(res.pvalue.data, xp.asarray(ref[1]))
  199. # test 2-arg
  200. common_mask = np.isnan(narrays[0]) | np.isnan(narrays[1])
  201. normalize = (np.nansum(narrays[1] * ~common_mask, axis=axis, keepdims=True)
  202. / np.nansum(narrays[0] * ~common_mask, axis=axis, keepdims=True))
  203. marrays[0] *= xp.asarray(normalize)
  204. narrays[0] *= normalize
  205. res = f(*marrays, **kwargs)
  206. ref = power_divergence_ref(*narrays, nan_policy='omit', lambda_=lambda_, **kwargs)
  207. xp_assert_close(res.statistic.data, xp.asarray(ref[0]))
  208. xp_assert_close(res.pvalue.data, xp.asarray(ref[1]))
  209. @make_xp_test_case(stats.combine_pvalues)
  210. @skip_backend('dask.array', reason='Arrays need `device` attribute: dask/dask#11711')
  211. @skip_backend('jax.numpy', reason="JAX doesn't allow item assignment.")
  212. @skip_backend('torch', reason="array-api-compat#242")
  213. @skip_backend('cupy', reason="special functions won't work")
  214. @pytest.mark.parametrize('method', ['fisher', 'pearson', 'mudholkar_george',
  215. 'tippett', 'stouffer'])
  216. @pytest.mark.parametrize('axis', [0, 1, None])
  217. def test_combine_pvalues(method, axis, xp):
  218. mxp, marrays, narrays = get_arrays(2, xp=xp, shape=(10, 11))
  219. kwargs = dict(method=method, axis=axis)
  220. res = stats.combine_pvalues(marrays[0], **kwargs)
  221. ref = stats.combine_pvalues(narrays[0], nan_policy='omit', **kwargs)
  222. xp_assert_close(res.statistic.data, xp.asarray(ref.statistic))
  223. xp_assert_close(res.pvalue.data, xp.asarray(ref.pvalue))
  224. if method != 'stouffer':
  225. return
  226. res = stats.combine_pvalues(marrays[0], weights=marrays[1], **kwargs)
  227. ref = stats.combine_pvalues(narrays[0], weights=narrays[1],
  228. nan_policy='omit', **kwargs)
  229. xp_assert_close(res.statistic.data, xp.asarray(ref.statistic))
  230. xp_assert_close(res.pvalue.data, xp.asarray(ref.pvalue))
  231. @make_xp_test_case(stats.ttest_ind_from_stats)
  232. @skip_backend('dask.array', reason='Arrays need `device` attribute: dask/dask#11711')
  233. @skip_backend('jax.numpy', reason="JAX doesn't allow item assignment.")
  234. @skip_backend('torch', reason="array-api-compat#242")
  235. @skip_backend('cupy', reason="special functions won't work")
  236. def test_ttest_ind_from_stats(xp):
  237. shape = (10, 11)
  238. mxp, marrays, narrays = get_arrays(6, xp=xp, shape=shape)
  239. mask = np.sum(np.stack([np.isnan(arg) for arg in narrays]), axis=0).astype(bool)
  240. narrays = [arg[~mask] for arg in narrays]
  241. marrays[2], marrays[5] = marrays[2] * 100, marrays[5] * 100
  242. narrays[2], narrays[5] = narrays[2] * 100, narrays[5] * 100
  243. res = stats.ttest_ind_from_stats(*marrays)
  244. ref = stats.ttest_ind_from_stats(*narrays)
  245. mask = xp.asarray(mask)
  246. assert xp.any(mask) and xp.any(~mask)
  247. xp_assert_close(res.statistic.data[~mask], xp.asarray(ref.statistic))
  248. xp_assert_close(res.pvalue.data[~mask], xp.asarray(ref.pvalue))
  249. xp_assert_close(res.statistic.mask, mask)
  250. xp_assert_close(res.pvalue.mask, mask)
  251. assert res.statistic.shape == shape
  252. assert res.pvalue.shape == shape
  253. @pytest.mark.skipif(version.parse(np.__version__) < version.parse("2"),
  254. reason="Call to _getnamespace fails with AttributeError")
  255. def test_length_nonmasked_marray_iterable_axis_raises():
  256. xp = marray._get_namespace(np)
  257. data = [[1.0, 2.0], [3.0, 4.0]]
  258. mask = [[False, False], [True, False]]
  259. marr = xp.asarray(data, mask=mask)
  260. # Axis tuples are not currently supported for MArray input.
  261. # This test can be removed after support is added.
  262. with pytest.raises(NotImplementedError,
  263. match="`axis` must be an integer or None for use with `MArray`"):
  264. _length_nonmasked(marr, axis=(0, 1), xp=xp)
  265. @make_xp_test_case(stats.directional_stats)
  266. @skip_backend('dask.array', reason='Arrays need `device` attribute: dask/dask#11711')
  267. @skip_backend('jax.numpy', reason="JAX doesn't allow item assignment.")
  268. @skip_backend('torch', reason="array-api-compat#242")
  269. @pytest.mark.filterwarnings("ignore::RuntimeWarning") # mdhaber/marray#120
  270. def test_directional_stats(xp):
  271. mxp, marrays, narrays = get_arrays(1, shape=(100, 3), xp=xp)
  272. res = stats.directional_stats(*marrays)
  273. narrays[0] = narrays[0][~np.any(np.isnan(narrays[0]), axis=1)]
  274. ref = stats.directional_stats(*narrays)
  275. xp_assert_close(res.mean_direction.data, xp.asarray(ref.mean_direction))
  276. xp_assert_close(res.mean_resultant_length.data,
  277. xp.asarray(ref.mean_resultant_length))
  278. assert not xp.any(res.mean_direction.mask)
  279. assert not xp.any(res.mean_resultant_length.mask)
  280. @make_xp_test_case(stats.bartlett)
  281. @skip_backend('dask.array', reason='Arrays need `device` attribute: dask/dask#11711')
  282. @skip_backend('jax.numpy', reason="JAX doesn't allow item assignment.")
  283. @skip_backend('torch', reason="array-api-compat#242")
  284. @skip_backend('cupy', reason="special functions won't work")
  285. @pytest.mark.parametrize('fun, kwargs', [
  286. (stats.bartlett, {}),
  287. (stats.f_oneway, {'equal_var': True}),
  288. (stats.f_oneway, {'equal_var': False}),
  289. ])
  290. @pytest.mark.parametrize('axis', [0, 1, None])
  291. def test_k_sample_tests(fun, kwargs, axis, xp):
  292. mxp, marrays, narrays = get_arrays(3, xp=xp)
  293. res = fun(*marrays, axis=axis, **kwargs)
  294. ref = fun(*narrays, nan_policy='omit', axis=axis, **kwargs)
  295. xp_assert_close(res.statistic.data, xp.asarray(ref.statistic))
  296. xp_assert_close(res.pvalue.data, xp.asarray(ref.pvalue))
  297. @skip_backend('dask.array', reason='Arrays need `device` attribute: dask/dask#11711')
  298. @skip_backend('jax.numpy', reason="JAX doesn't allow item assignment.")
  299. @skip_backend('torch', reason="array-api-compat#242")
  300. @skip_backend('cupy', reason="special functions won't work")
  301. @pytest.mark.parametrize('f', [make_xp_pytest_param(stats.pearsonr),
  302. make_xp_pytest_param(stats.pointbiserialr)])
  303. def test_pearsonr(f, xp):
  304. mxp, marrays, narrays = get_arrays(2, shape=(25,), xp=xp)
  305. res = f(*marrays)
  306. x, y = narrays
  307. mask = np.isnan(x) | np.isnan(y)
  308. ref = f(x[~mask], y[~mask])
  309. xp_assert_close(res.statistic.data, xp.asarray(ref.statistic))
  310. xp_assert_close(res.pvalue.data, xp.asarray(ref.pvalue))
  311. if f == stats.pearsonr:
  312. res_ci_low, res_ci_high = res.confidence_interval()
  313. ref_ci_low, ref_ci_high = ref.confidence_interval()
  314. xp_assert_close(res_ci_low.data, xp.asarray(ref_ci_low))
  315. xp_assert_close(res_ci_high.data, xp.asarray(ref_ci_high))
  316. @make_xp_test_case(stats.entropy)
  317. @skip_backend('dask.array', reason='Arrays need `device` attribute: dask/dask#11711')
  318. @skip_backend('jax.numpy', reason="JAX doesn't allow item assignment.")
  319. @pytest.mark.parametrize('qk', [False, True])
  320. @pytest.mark.parametrize('axis', [0, 1, None])
  321. def test_entropy(qk, axis, xp):
  322. mxp, marrays, narrays = get_arrays(2 if qk else 1, xp=xp)
  323. res = stats.entropy(*marrays, axis=axis)
  324. ref = stats.entropy(*narrays, nan_policy='omit', axis=axis)
  325. xp_assert_close(res.data, xp.asarray(ref))