test_quantile.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. import pytest
  2. import numpy as np
  3. from scipy import stats
  4. from scipy.stats._quantile import _xp_searchsorted
  5. from scipy._lib._array_api import (
  6. xp_default_dtype,
  7. is_numpy,
  8. is_torch,
  9. is_jax,
  10. make_xp_test_case,
  11. SCIPY_ARRAY_API,
  12. xp_size,
  13. xp_copy,
  14. )
  15. from scipy._lib._array_api_no_0d import xp_assert_close, xp_assert_equal
  16. from scipy._lib._util import _apply_over_batch
# Shorthand for the `skip_xp_backends` pytest marker used as a decorator on
# several tests in this module.
skip_xp_backends = pytest.mark.skip_xp_backends
# NOTE(review): presumably read by the array-API test machinery to decide which
# modules to instrument for lazy backends - confirm against the test plugin.
lazy_xp_modules = [stats]
  19. @_apply_over_batch(('x', 1), ('p', 1))
  20. def quantile_reference_last_axis(x, p, nan_policy, method):
  21. if nan_policy == 'omit':
  22. x = x[~np.isnan(x)]
  23. p_mask = np.isnan(p)
  24. p = p.copy()
  25. p[p_mask] = 0.5
  26. if method == 'harrell-davis':
  27. # hdquantiles returns masked element if length along axis is 1 (bug)
  28. res = (np.full_like(p, x[0]) if x.size == 1
  29. else stats.mstats.hdquantiles(x, p).data)
  30. elif method.startswith('round'):
  31. res = winsor_reference_1d(np.sort(x), p, method)
  32. else:
  33. res = np.quantile(x, p, method=method)
  34. res = np.asarray(res)
  35. if nan_policy == 'propagate' and np.any(np.isnan(x)):
  36. res[:] = np.nan
  37. res[p_mask] = np.nan
  38. return res
  39. @np.vectorize(excluded={0, 2}) # type: ignore[call-arg]
  40. def winsor_reference_1d(y, p, method):
  41. # Adapted directly from the documentation
  42. # Note: `y` is the sorted data array
  43. n = len(y)
  44. if method == 'round_nearest':
  45. j = int(np.round(p * n) if p < 0.5 else np.round(n * p - 1))
  46. elif method == 'round_outward':
  47. j = int(np.floor(p * n) if p < 0.5 else np.ceil(n * p - 1))
  48. elif method == 'round_inward':
  49. j = int(np.ceil(p * n) if p < 0.5 else np.floor(n * p - 1))
  50. return y[j]
  51. def quantile_reference(x, p, *, axis, nan_policy, keepdims, method):
  52. x, p = np.moveaxis(x, axis, -1), np.moveaxis(np.atleast_1d(p), axis, -1)
  53. res = quantile_reference_last_axis(x, p, nan_policy, method)
  54. res = np.moveaxis(res, -1, axis)
  55. if not keepdims:
  56. res = np.squeeze(res, axis=axis)
  57. return res
  58. @make_xp_test_case(stats.quantile)
  59. class TestQuantile:
  60. def test_input_validation(self, xp):
  61. x = xp.asarray([1, 2, 3])
  62. p = xp.asarray(0.5)
  63. message = "`x` must have real dtype."
  64. with pytest.raises(ValueError, match=message):
  65. stats.quantile(xp.asarray([True, False]), p)
  66. with pytest.raises(ValueError):
  67. stats.quantile(xp.asarray([1+1j, 2]), p)
  68. message = "`p` must have real floating dtype."
  69. with pytest.raises(ValueError, match=message):
  70. stats.quantile(x, xp.asarray([0, 1]))
  71. message = "`weights` must have real dtype."
  72. with pytest.raises(ValueError, match=message):
  73. stats.quantile(x, p, weights=xp.astype(x, xp.complex64))
  74. message = "`axis` must be an integer or None."
  75. with pytest.raises(ValueError, match=message):
  76. stats.quantile(x, p, axis=0.5)
  77. with pytest.raises(ValueError, match=message):
  78. stats.quantile(x, p, axis=(0, -1))
  79. message = "`axis` is not compatible with the shapes of the inputs."
  80. with pytest.raises(ValueError, match=message):
  81. stats.quantile(x, p, axis=2)
  82. if not is_jax(xp): # no data-dependent input validation for lazy arrays
  83. message = "The input contains nan values"
  84. with pytest.raises(ValueError, match=message):
  85. stats.quantile(xp.asarray([xp.nan, 1, 2]), p, nan_policy='raise')
  86. message = "`method` must be one of..."
  87. with pytest.raises(ValueError, match=message):
  88. stats.quantile(x, p, method='a duck')
  89. message = "`method='harrell-davis'` does not support `weights`."
  90. with pytest.raises(ValueError, match=message):
  91. stats.quantile(x, p, weights=x, method='harrell-davis')
  92. message = "`method='round_nearest'` does not support `weights`."
  93. with pytest.raises(ValueError, match=message):
  94. stats.quantile(x, p, weights=x, method='round_nearest')
  95. message = "If specified, `keepdims` must be True or False."
  96. with pytest.raises(ValueError, match=message):
  97. stats.quantile(x, p, keepdims=42)
  98. message = "`keepdims` may be False only if the length of `p` along `axis` is 1."
  99. with pytest.raises(ValueError, match=message):
  100. stats.quantile(x, xp.asarray([0.5, 0.6]), keepdims=False)
  101. def _get_weights_x_rep(self, x, axis, rng):
  102. x = np.swapaxes(x, axis, -1)
  103. ndim = x.ndim
  104. x = np.atleast_2d(x)
  105. counts = rng.integers(10, size=x.shape[-1], dtype=np.int32)
  106. x_rep = []
  107. weights = []
  108. for x_ in x:
  109. counts_ = rng.permuted(counts)
  110. x_rep.append(np.repeat(x_, counts_))
  111. weights.append(counts_)
  112. x_rep, weights = np.stack(x_rep), np.stack(weights)
  113. if ndim < 2:
  114. x_rep, weights = np.squeeze(x_rep, axis=0), np.squeeze(weights, axis=0)
  115. x_rep, weights = np.swapaxes(x_rep, -1, axis), np.swapaxes(weights, -1, axis)
  116. weights = np.asarray(weights, dtype=x.dtype)
  117. return weights, x_rep
  118. @skip_xp_backends(cpu_only=True, reason="PyTorch doesn't have `betainc`.",
  119. exceptions=['cupy'])
  120. @pytest.mark.parametrize('method',
  121. ['inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
  122. 'hazen', 'interpolated_inverted_cdf', 'linear',
  123. 'median_unbiased', 'normal_unbiased', 'weibull',
  124. 'harrell-davis', 'round_nearest', 'round_outward', 'round_inward',
  125. '_lower', '_higher', '_midpoint', '_nearest'])
  126. @pytest.mark.parametrize('shape_x, shape_p, axis',
  127. [(10, None, -1), (10, 10, -1), (10, (2, 3), -1), ((10, 2), None, 0)])
  128. @pytest.mark.parametrize('weights', [False, True])
  129. def test_against_reference(self, method, shape_x, shape_p, axis, weights, xp):
  130. # Test all methods with various data shapes
  131. if weights and (method.startswith('_') or method.startswith('round')
  132. or method=='harrell-davis'):
  133. pytest.skip('`weights` not supported by private (legacy) methods.')
  134. dtype = xp_default_dtype(xp)
  135. rng = np.random.default_rng(23458924568734956)
  136. x = rng.random(size=shape_x)
  137. p = rng.random(size=shape_p)
  138. if weights:
  139. weights, x_rep = self._get_weights_x_rep(x, axis, rng)
  140. else:
  141. weights, x_rep = None, x
  142. ref = quantile_reference(
  143. x_rep, p, method=method[1:] if method.startswith('_') else method,
  144. axis=axis, nan_policy='propagate', keepdims=shape_p is not None)
  145. x, p = xp.asarray(x, dtype=dtype), xp.asarray(p, dtype=dtype)
  146. weights = weights if weights is None else xp.asarray(weights, dtype=dtype)
  147. res = stats.quantile(x, p, method=method, weights=weights, axis=axis)
  148. xp_assert_close(res, xp.asarray(ref, dtype=dtype))
  149. @pytest.mark.filterwarnings("ignore:torch.searchsorted:UserWarning")
  150. @skip_xp_backends(cpu_only=True, reason="PyTorch doesn't have `betainc`.",
  151. exceptions=['cupy', 'jax.numpy'])
  152. @pytest.mark.parametrize('axis', [0, 1])
  153. @pytest.mark.parametrize('keepdims', [False, True])
  154. @pytest.mark.parametrize('nan_policy', ['omit', 'propagate', 'marray'])
  155. @pytest.mark.parametrize('dtype', ['float32', 'float64'])
  156. @pytest.mark.parametrize('method', ['linear', 'harrell-davis', 'round_nearest'])
  157. @pytest.mark.parametrize('weights', [False, True])
  158. def test_against_reference_2(self, axis, keepdims, nan_policy,
  159. dtype, method, weights, xp):
  160. # Test some methods with various combinations of arguments
  161. if is_jax(xp) and nan_policy == 'marray': # mdhaber/marray#146
  162. pytest.skip("`marray` currently incompatible with JAX")
  163. if weights and method in {'harrell-davis', 'round_nearest'}:
  164. pytest.skip("These methods don't yet support weights")
  165. rng = np.random.default_rng(23458924568734956)
  166. shape = (5, 6)
  167. x = rng.random(size=shape).astype(dtype)
  168. p = rng.random(size=shape).astype(dtype)
  169. mask = rng.random(size=shape) > 0.8
  170. assert np.any(mask)
  171. x[mask] = np.nan
  172. if not keepdims:
  173. p = np.mean(p, axis=axis, keepdims=True)
  174. # inject p = 0 and p = 1 to test edge cases
  175. # Currently would fail with CuPy/JAX (cupy/cupy#8934, jax-ml/jax#21900);
  176. # remove the `if` when those are resolved.
  177. if is_numpy(xp):
  178. p0 = p.ravel()
  179. p0[1] = 0.
  180. p0[-2] = 1.
  181. dtype = getattr(xp, dtype)
  182. if weights:
  183. weights, x_rep = self._get_weights_x_rep(x, axis, rng)
  184. weights = weights if weights is None else xp.asarray(weights)
  185. else:
  186. weights, x_rep = None, x
  187. if nan_policy == 'marray':
  188. if not SCIPY_ARRAY_API:
  189. pytest.skip("MArray is only available if SCIPY_ARRAY_API=1")
  190. if weights is not None:
  191. pytest.skip("MArray is not yet compatible with weights")
  192. marray = pytest.importorskip('marray')
  193. kwargs = dict(axis=axis, keepdims=keepdims, method=method)
  194. mxp = marray._get_namespace(xp)
  195. x_mp = mxp.asarray(x, mask=mask)
  196. weights = weights if weights is None else mxp.asarray(weights)
  197. res = stats.quantile(x_mp, mxp.asarray(p), weights=weights, **kwargs)
  198. ref = quantile_reference(x_rep, p, nan_policy='omit', **kwargs)
  199. xp_assert_close(res.data, xp.asarray(ref, dtype=dtype))
  200. return
  201. kwargs = dict(axis=axis, keepdims=keepdims,
  202. nan_policy=nan_policy, method=method)
  203. res = stats.quantile(xp.asarray(x), xp.asarray(p), weights=weights, **kwargs)
  204. ref = quantile_reference(x_rep, p, **kwargs)
  205. xp_assert_close(res, xp.asarray(ref, dtype=dtype))
  206. def test_integer_input_output_dtype(self, xp):
  207. res = stats.quantile(xp.arange(10, dtype=xp.int64), 0.5)
  208. assert res.dtype == xp_default_dtype(xp)
  209. @pytest.mark.parametrize('x, p, ref, kwargs',
  210. [([], 0.5, np.nan, {}),
  211. ([1, 2, 3], [-1, 0, 1, 1.5, np.nan], [np.nan, 1, 3, np.nan, np.nan], {}),
  212. ([1, 2, 3], [], [], {}),
  213. ([[np.nan, 2]], 0.5, [np.nan, 2], {'nan_policy': 'omit'}),
  214. ([[], []], 0.5, np.full(2, np.nan), {'axis': -1}),
  215. ([[], []], 0.5, np.zeros((0,)), {'axis': 0, 'keepdims': False}),
  216. ([[], []], 0.5, np.zeros((1, 0)), {'axis': 0, 'keepdims': True}),
  217. ([], [0.5, 0.6], np.full(2, np.nan), {}),
  218. (np.arange(1, 28).reshape((3, 3, 3)), 0.5, [[[14.]]],
  219. {'axis': None, 'keepdims': True}),
  220. ([[1, 2], [3, 4]], [0.25, 0.5, 0.75], [[1.75, 2.5, 3.25]],
  221. {'axis': None, 'keepdims': True}),
  222. # Known issue:
  223. # ([1, 2, 3], 0.5, 2., {'weights': [0, 0, 0]})
  224. # See https://github.com/scipy/scipy/pull/23941#issuecomment-3503554361
  225. ])
  226. def test_edge_cases(self, x, p, ref, kwargs, xp):
  227. default_dtype = xp_default_dtype(xp)
  228. x, p, ref = xp.asarray(x), xp.asarray(p), xp.asarray(ref, dtype=default_dtype)
  229. res = stats.quantile(x, p, **kwargs)
  230. xp_assert_equal(res, ref)
  231. @pytest.mark.parametrize('axis', [0, 1, 2])
  232. @pytest.mark.parametrize('keepdims', [False, True])
  233. def test_size_0(self, axis, keepdims, xp):
  234. shape = [3, 4, 0]
  235. out_shape = shape.copy()
  236. if keepdims:
  237. out_shape[axis] = 1
  238. else:
  239. out_shape.pop(axis)
  240. res = stats.quantile(xp.zeros(tuple(shape)), 0.5, axis=axis, keepdims=keepdims)
  241. assert res.shape == tuple(out_shape)
  242. @pytest.mark.parametrize('method',
  243. ['inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
  244. '_lower', '_higher', '_midpoint', '_nearest'])
  245. def test_transition(self, method, xp):
  246. # test that values of discontinuous estimators are correct when
  247. # p*n + m - 1 is integral.
  248. if method == 'closest_observation' and np.__version__ < '2.0.1':
  249. pytest.skip('Bug in np.quantile (numpy/numpy#26656) fixed in 2.0.1')
  250. x = np.arange(8., dtype=np.float64)
  251. p = np.arange(0, 1.03125, 0.03125)
  252. res = stats.quantile(xp.asarray(x), xp.asarray(p), method=method)
  253. ref = np.quantile(x, p, method=method[1:] if method.startswith('_') else method)
  254. xp_assert_equal(res, xp.asarray(ref, dtype=xp.float64))
  255. @pytest.mark.parametrize('zero_weights', [False, True])
  256. def test_weights_against_numpy(self, zero_weights, xp):
  257. if is_numpy(xp) and xp.__version__ < "2.0":
  258. pytest.skip('`weights` not supported by NumPy < 2.0.')
  259. dtype = xp_default_dtype(xp)
  260. rng = np.random.default_rng(85468924398205602)
  261. method = 'inverted_cdf'
  262. x = rng.random(size=100)
  263. weights = rng.random(size=100)
  264. if zero_weights:
  265. weights[weights < 0.5] = 0
  266. p = np.linspace(0., 1., 300)
  267. res = stats.quantile(xp.asarray(x, dtype=dtype), xp.asarray(p, dtype=dtype),
  268. method=method, weights=xp.asarray(weights, dtype=dtype))
  269. ref = np.quantile(x, p, method=method, weights=weights)
  270. xp_assert_close(res, xp.asarray(ref, dtype=dtype))
  271. @pytest.mark.parametrize('method',
  272. ['inverted_cdf', 'averaged_inverted_cdf', 'closest_observation', 'hazen',
  273. 'interpolated_inverted_cdf', 'linear','median_unbiased', 'normal_unbiased',
  274. 'weibull'])
  275. def test_zero_weights(self, method, xp):
  276. rng = np.random.default_rng(85468924398205602)
  277. # test 1-D versus eliminating zero-weighted values
  278. n = 100
  279. x = xp.asarray(rng.random(size=n))
  280. x0 = xp_copy(x)
  281. p = xp.asarray(rng.random(size=n))
  282. i_zero = xp.asarray(rng.random(size=n) < 0.1)
  283. weights = xp.asarray(rng.random(size=n))
  284. weights = xp.where(i_zero, 0., weights)
  285. res = stats.quantile(x, p, weights=weights, method=method)
  286. ref = stats.quantile(x[~i_zero], p, weights=weights[~i_zero], method=method)
  287. xp_assert_close(res, ref)
  288. xp_assert_equal(x, x0) # no input mutation
  289. # test multi-D versus `nan_policy='omit'`
  290. shape = (5, 100)
  291. x = xp.asarray(rng.random(size=shape))
  292. x0 = xp_copy(x)
  293. p = xp.asarray(rng.random(size=shape))
  294. i_zero = xp.asarray(rng.random(size=shape) < 0.1)
  295. weights = xp.asarray(rng.random(size=shape))
  296. x_nanned = xp.where(i_zero, xp.nan, x)
  297. weights_zeroed = xp.where(i_zero, 0., weights)
  298. res = stats.quantile(x, p, weights=weights_zeroed, method=method, axis=-1)
  299. ref = stats.quantile(x_nanned, p, weights=weights,
  300. nan_policy='omit', method=method, axis=-1)
  301. xp_assert_close(res, ref)
  302. xp_assert_equal(x, x0) # no input mutation
  303. @pytest.mark.filterwarnings("ignore:torch.searchsorted:UserWarning")
  304. @pytest.mark.parametrize('method',
  305. ['inverted_cdf', 'averaged_inverted_cdf', 'closest_observation', 'hazen',
  306. 'interpolated_inverted_cdf', 'linear','median_unbiased', 'normal_unbiased',
  307. 'weibull'])
  308. @pytest.mark.parametrize('shape', [50, (50, 3)])
  309. def test_unity_weights(self, method, shape, xp):
  310. # Check that result is unchanged if all weights are `1.0`
  311. rng = np.random.default_rng(28546892439820560)
  312. x = xp.asarray(rng.random(size=shape))
  313. p = xp.asarray(rng.random(size=shape))
  314. weights = xp.ones_like(x)
  315. res = stats.quantile(x, p, weights=weights, method=method)
  316. ref = stats.quantile(x, p, method=method)
  317. xp_assert_close(res, ref)
  318. @_apply_over_batch(('a', 1), ('v', 1))
  319. def np_searchsorted(a, v, side):
  320. return np.searchsorted(a, v, side=side)
  321. @make_xp_test_case(_xp_searchsorted)
  322. class Test_XPSearchsorted:
  323. @pytest.mark.parametrize('side', ['left', 'right'])
  324. @pytest.mark.parametrize('ties', [False, True])
  325. @pytest.mark.parametrize('shape', [0, 1, 2, 10, 11, 1000, 10001,
  326. (2, 0), (0, 2), (2, 10), (2, 3, 11)])
  327. @pytest.mark.parametrize('nans_x', [False, True])
  328. @pytest.mark.parametrize('infs_x', [False, True])
  329. def test_nd(self, side, ties, shape, nans_x, infs_x, xp):
  330. if nans_x and is_torch(xp):
  331. pytest.skip('torch sorts NaNs differently')
  332. rng = np.random.default_rng(945298725498274853)
  333. if ties:
  334. x = rng.integers(5, size=shape)
  335. else:
  336. x = rng.random(shape)
  337. # float32 is to accommodate JAX - nextafter with `float64` is too small?
  338. x = np.asarray(x, dtype=np.float32)
  339. xr = np.nextafter(x, np.inf)
  340. xl = np.nextafter(x, -np.inf)
  341. x_ = np.asarray([-np.inf, np.inf, np.nan])
  342. x_ = np.broadcast_to(x_, x.shape[:-1] + (3,))
  343. y = rng.permuted(np.concatenate((xl, x, xr, x_), axis=-1), axis=-1)
  344. if nans_x:
  345. mask = rng.random(shape) < 0.1
  346. x[mask] = np.nan
  347. if infs_x:
  348. mask = rng.random(shape) < 0.1
  349. x[mask] = -np.inf
  350. mask = rng.random(shape) > 0.9
  351. x[mask] = np.inf
  352. x = np.sort(x, axis=-1)
  353. x, y = np.asarray(x, dtype=np.float64), np.asarray(y, dtype=np.float64)
  354. xp_default_int = xp.asarray(1).dtype
  355. if xp_size(x) == 0 and x.ndim > 0 and x.shape[-1] != 0:
  356. ref = xp.empty(x.shape[:-1] + (y.shape[-1],), dtype=xp_default_int)
  357. else:
  358. ref = xp.asarray(np_searchsorted(x, y, side=side), dtype=xp_default_int)
  359. x, y = xp.asarray(x), xp.asarray(y)
  360. res = _xp_searchsorted(x, y, side=side)
  361. xp_assert_equal(res, ref)