_mannwhitneyu.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
  1. import threading
  2. import numpy as np
  3. from collections import namedtuple
  4. from scipy._lib._array_api import array_namespace, xp_capabilities, xp_size, xp_promote
  5. from scipy._lib import array_api_extra as xpx
  6. from scipy import special
  7. from scipy import stats
  8. from scipy.stats._stats_py import _rankdata
  9. from ._axis_nan_policy import _axis_nan_policy_factory, _broadcast_concatenate
  10. class _MWU:
  11. '''Distribution of MWU statistic under the null hypothesis'''
  12. def __init__(self, n1, n2):
  13. self._reset(n1, n2)
  14. def set_shapes(self, n1, n2):
  15. n1, n2 = min(n1, n2), max(n1, n2)
  16. if (n1, n2) == (self.n1, self.n2):
  17. return
  18. self.n1 = n1
  19. self.n2 = n2
  20. self.s_array = np.zeros(0, dtype=int)
  21. self.configurations = np.zeros(0, dtype=np.uint64)
  22. def reset(self):
  23. self._reset(self.n1, self.n2)
  24. def _reset(self, n1, n2):
  25. self.n1 = None
  26. self.n2 = None
  27. self.set_shapes(n1, n2)
  28. def pmf(self, k):
  29. # In practice, `pmf` is never called with k > m*n/2.
  30. # If it were, we'd exploit symmetry here:
  31. # k = np.array(k, copy=True)
  32. # k2 = m*n - k
  33. # i = k2 < k
  34. # k[i] = k2[i]
  35. pmfs = self.build_u_freqs_array(np.max(k))
  36. return pmfs[k]
  37. def cdf(self, k):
  38. '''Cumulative distribution function'''
  39. # In practice, `cdf` is never called with k > m*n/2.
  40. # If it were, we'd exploit symmetry here rather than in `sf`
  41. pmfs = self.build_u_freqs_array(np.max(k))
  42. cdfs = np.cumsum(pmfs)
  43. return cdfs[k]
  44. def sf(self, k):
  45. '''Survival function'''
  46. # Note that both CDF and SF include the PMF at k. The p-value is
  47. # calculated from the SF and should include the mass at k, so this
  48. # is desirable
  49. # Use the fact that the distribution is symmetric and sum from the left
  50. kc = np.asarray(self.n1*self.n2 - k) # complement of k
  51. i = k < kc
  52. if np.any(i):
  53. kc[i] = k[i]
  54. cdfs = np.asarray(self.cdf(kc))
  55. cdfs[i] = 1. - cdfs[i] + self.pmf(kc[i])
  56. else:
  57. cdfs = np.asarray(self.cdf(kc))
  58. return cdfs[()]
  59. # build_sigma_array and build_u_freqs_array adapted from code
  60. # by @toobaz with permission. Thanks to @andreasloe for the suggestion.
  61. # See https://github.com/scipy/scipy/pull/4933#issuecomment-1898082691
  62. def build_sigma_array(self, a):
  63. n1, n2 = self.n1, self.n2
  64. if a + 1 <= self.s_array.size:
  65. return self.s_array[1:a+1]
  66. s_array = np.zeros(a + 1, dtype=int)
  67. for d in np.arange(1, n1 + 1):
  68. # All multiples of d, except 0:
  69. indices = np.arange(d, a + 1, d)
  70. # \epsilon_d = 1:
  71. s_array[indices] += d
  72. for d in np.arange(n2 + 1, n2 + n1 + 1):
  73. # All multiples of d, except 0:
  74. indices = np.arange(d, a + 1, d)
  75. # \epsilon_d = -1:
  76. s_array[indices] -= d
  77. # We don't need 0:
  78. self.s_array = s_array
  79. return s_array[1:]
  80. def build_u_freqs_array(self, maxu):
  81. """
  82. Build all the array of frequencies for u from 0 to maxu.
  83. Assumptions:
  84. n1 <= n2
  85. maxu <= n1 * n2 / 2
  86. """
  87. n1, n2 = self.n1, self.n2
  88. total = special.binom(n1 + n2, n1)
  89. if maxu + 1 <= self.configurations.size:
  90. return self.configurations[:maxu + 1] / total
  91. s_array = self.build_sigma_array(maxu)
  92. # Start working with ints, for maximum precision and efficiency:
  93. configurations = np.zeros(maxu + 1, dtype=np.uint64)
  94. configurations_is_uint = True
  95. uint_max = np.iinfo(np.uint64).max
  96. # How many ways to have U=0? 1
  97. configurations[0] = 1
  98. for u in np.arange(1, maxu + 1):
  99. coeffs = s_array[u - 1::-1]
  100. new_val = np.dot(configurations[:u], coeffs) / u
  101. if new_val > uint_max and configurations_is_uint:
  102. # OK, we got into numbers too big for uint64.
  103. # So now we start working with floats.
  104. # By doing this since the beginning, we would have lost precision.
  105. # (And working on python long ints would be unbearably slow)
  106. configurations = configurations.astype(float)
  107. configurations_is_uint = False
  108. configurations[u] = new_val
  109. self.configurations = configurations
  110. return configurations / total
# Maintain state for faster repeat calls to `mannwhitneyu`.
# A `_MWU` instance is created lazily, once per thread, and stored as the
# attribute `s` on this thread-local variable inside mannwhitneyu().
_mwu_state = threading.local()
  115. def _get_mwu_z(U, n1, n2, t, continuity=True, *, xp):
  116. '''Standardized MWU statistic'''
  117. # Follows mannwhitneyu [2]
  118. mu = n1 * n2 / 2
  119. n = n1 + n2
  120. # Tie correction according to [2], "Normal approximation and tie correction"
  121. # "A more computationally-efficient form..."
  122. tie_term = xp.sum(t**3 - t, axis=-1)
  123. s = xp.sqrt(n1*n2/12 * ((n + 1) - tie_term/(n*(n-1))))
  124. numerator = U - mu
  125. # Continuity correction.
  126. # Because SF is always used to calculate the p-value, we can always
  127. # _subtract_ 0.5 for the continuity correction. This always increases the
  128. # p-value to account for the rest of the probability mass _at_ q = U.
  129. if continuity:
  130. numerator -= 0.5
  131. # no problem evaluating the norm SF at an infinity
  132. with np.errstate(divide='ignore', invalid='ignore'):
  133. z = numerator / s
  134. return z
  135. def _mwu_input_validation(x, y, use_continuity, alternative, axis, method):
  136. ''' Input validation and standardization for mannwhitneyu '''
  137. xp = array_namespace(x, y)
  138. x, y = xpx.atleast_nd(x, ndim=1), xpx.atleast_nd(y, ndim=1)
  139. if xp.any(xp.isnan(x)) or xp.any(xp.isnan(y)):
  140. raise ValueError('`x` and `y` must not contain NaNs.')
  141. if xp_size(x) == 0 or xp_size(y) == 0:
  142. raise ValueError('`x` and `y` must be of nonzero size.')
  143. x, y = xp_promote(x, y, force_floating=True, xp=xp)
  144. bools = {True, False}
  145. if use_continuity not in bools:
  146. raise ValueError(f'`use_continuity` must be one of {bools}.')
  147. alternatives = {"two-sided", "less", "greater"}
  148. alternative = alternative.lower()
  149. if alternative not in alternatives:
  150. raise ValueError(f'`alternative` must be one of {alternatives}.')
  151. axis_int = int(axis)
  152. if axis != axis_int:
  153. raise ValueError('`axis` must be an integer.')
  154. if not isinstance(method, stats.PermutationMethod):
  155. methods = {"asymptotic", "exact", "auto"}
  156. method = method.lower()
  157. if method not in methods:
  158. raise ValueError(f'`method` must be one of {methods}.')
  159. return x, y, use_continuity, alternative, axis_int, method, xp
  160. def _mwu_choose_method(n1, n2, ties):
  161. """Choose method 'asymptotic' or 'exact' depending on input size, ties"""
  162. # if both inputs are large, asymptotic is OK
  163. if n1 > 8 and n2 > 8:
  164. return "asymptotic"
  165. # if there are any ties, asymptotic is preferred
  166. if ties:
  167. return "asymptotic"
  168. return "exact"
# Result type returned by `mannwhitneyu`: the U statistic and the p-value.
MannwhitneyuResult = namedtuple('MannwhitneyuResult', ('statistic', 'pvalue'))
@xp_capabilities(cpu_only=True,  # exact calculation only implemented in NumPy
                 skip_backends=[('cupy', 'needs rankdata'),
                                ('dask.array', 'needs rankdata')],
                 jax_jit=False)
@_axis_nan_policy_factory(MannwhitneyuResult, n_samples=2)
def mannwhitneyu(x, y, use_continuity=True, alternative="two-sided",
                 axis=0, method="auto"):
    r'''Perform the Mann-Whitney U rank test on two independent samples.

    The Mann-Whitney U test is a nonparametric test of the null hypothesis
    that the distribution underlying sample `x` is the same as the
    distribution underlying sample `y`. It is often used as a test of
    difference in location between distributions.

    Parameters
    ----------
    x, y : array-like
        N-d arrays of samples. The arrays must be broadcastable except along
        the dimension given by `axis`.
    use_continuity : bool, optional
        Whether a continuity correction (1/2) should be applied.
        Default is True. The correction only affects the *p*-value when
        `method` is ``'asymptotic'``; it has no effect otherwise.
    alternative : {'two-sided', 'less', 'greater'}, optional
        Defines the alternative hypothesis. Default is 'two-sided'.
        Let *SX(u)* and *SY(u)* be the survival functions of the
        distributions underlying `x` and `y`, respectively. Then the following
        alternative hypotheses are available:

        * 'two-sided': the distributions are not equal, i.e. *SX(u) ≠ SY(u)* for
          at least one *u*.
        * 'less': the distribution underlying `x` is stochastically less
          than the distribution underlying `y`, i.e. *SX(u) < SY(u)* for all *u*.
        * 'greater': the distribution underlying `x` is stochastically greater
          than the distribution underlying `y`, i.e. *SX(u) > SY(u)* for all *u*.

        Under a more restrictive set of assumptions, the alternative hypotheses
        can be expressed in terms of the locations of the distributions;
        see [5]_ section 5.1.
    axis : int, optional
        Axis along which to perform the test. Default is 0.
    method : {'auto', 'asymptotic', 'exact'} or `PermutationMethod` instance, optional
        Selects the method used to calculate the *p*-value.
        Default is 'auto'. The following options are available.

        * ``'asymptotic'``: compares the standardized test statistic
          against the normal distribution, correcting for ties.
        * ``'exact'``: computes the exact *p*-value by comparing the observed
          :math:`U` statistic against the exact distribution of the :math:`U`
          statistic under the null hypothesis. No correction is made for ties.
        * ``'auto'``: chooses ``'exact'`` when the size of one of the samples
          is less than or equal to 8 and there are no ties;
          chooses ``'asymptotic'`` otherwise.
        * `PermutationMethod` instance. In this case, the p-value
          is computed using `permutation_test` with the provided
          configuration options and other appropriate settings.

    Returns
    -------
    res : MannwhitneyuResult
        An object containing attributes:

        statistic : float
            The Mann-Whitney U statistic corresponding with sample `x`. See
            Notes for the test statistic corresponding with sample `y`.
        pvalue : float
            The associated *p*-value for the chosen `alternative`.

    Notes
    -----
    If ``U1`` is the statistic corresponding with sample `x`, then the
    statistic corresponding with sample `y` is
    ``U2 = x.shape[axis] * y.shape[axis] - U1``.

    `mannwhitneyu` is for independent samples. For related / paired samples,
    consider `scipy.stats.wilcoxon`.

    `method` ``'exact'`` is recommended when there are no ties and when either
    sample size is less than 8 [1]_. The implementation follows the algorithm
    reported in [3]_.
    Note that the exact method is *not* corrected for ties, but
    `mannwhitneyu` will not raise errors or warnings if there are ties in the
    data. If there are ties and either sample is small (fewer than ~10
    observations), consider passing an instance of `PermutationMethod`
    as the `method` to perform a permutation test.

    The Mann-Whitney U test is a non-parametric version of the t-test for
    independent samples. When the means of samples from the populations
    are normally distributed, consider `scipy.stats.ttest_ind`.

    See Also
    --------
    scipy.stats.wilcoxon, scipy.stats.ranksums, scipy.stats.ttest_ind

    References
    ----------
    .. [1] H.B. Mann and D.R. Whitney, "On a test of whether one of two random
           variables is stochastically larger than the other", The Annals of
           Mathematical Statistics, Vol. 18, pp. 50-60, 1947.
    .. [2] Mann-Whitney U Test, Wikipedia,
           http://en.wikipedia.org/wiki/Mann-Whitney_U_test
    .. [3] Andreas Löffler,
           "Über eine Partition der nat. Zahlen und ihr Anwendung beim U-Test",
           Wiss. Z. Univ. Halle, XXXII'83 pp. 87-89.
    .. [4] Rosie Shier, "Statistics: 2.3 The Mann-Whitney U Test", Mathematics
           Learning Support Centre, 2004.
    .. [5] Michael P. Fay and Michael A. Proschan. "Wilcoxon-Mann-Whitney
           or t-test? On assumptions for hypothesis tests and multiple \
           interpretations of decision rules." Statistics surveys, Vol. 4, pp.
           1-39, 2010. https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2857732/

    Examples
    --------
    We follow the example from [4]_: nine randomly sampled young adults were
    diagnosed with type II diabetes at the ages below.

    >>> males = [19, 22, 16, 29, 24]
    >>> females = [20, 11, 17, 12]

    We use the Mann-Whitney U test to assess whether there is a statistically
    significant difference in the diagnosis age of males and females.
    The null hypothesis is that the distribution of male diagnosis ages is
    the same as the distribution of female diagnosis ages. We decide
    that a confidence level of 95% is required to reject the null hypothesis
    in favor of the alternative that the distributions are different.
    Since the number of samples is very small and there are no ties in the
    data, we can compare the observed test statistic against the *exact*
    distribution of the test statistic under the null hypothesis.

    >>> from scipy.stats import mannwhitneyu
    >>> U1, p = mannwhitneyu(males, females, method="exact")
    >>> print(U1)
    17.0

    `mannwhitneyu` always reports the statistic associated with the first
    sample, which, in this case, is males. This agrees with :math:`U_M = 17`
    reported in [4]_. The statistic associated with the second sample
    can be calculated:

    >>> nx, ny = len(males), len(females)
    >>> U2 = nx*ny - U1
    >>> print(U2)
    3.0

    This agrees with :math:`U_F = 3` reported in [4]_. The two-sided
    *p*-value can be calculated from either statistic, and the value produced
    by `mannwhitneyu` agrees with :math:`p = 0.11` reported in [4]_.

    >>> print(p)
    0.1111111111111111

    The exact distribution of the test statistic is asymptotically normal, so
    the example continues by comparing the exact *p*-value against the
    *p*-value produced using the normal approximation.

    >>> _, pnorm = mannwhitneyu(males, females, method="asymptotic")
    >>> print(pnorm)
    0.11134688653314041

    Here `mannwhitneyu`'s reported *p*-value appears to conflict with the
    value :math:`p = 0.09` given in [4]_. The reason is that [4]_
    does not apply the continuity correction performed by `mannwhitneyu`;
    `mannwhitneyu` reduces the distance between the test statistic and the
    mean :math:`\mu = n_x n_y / 2` by 0.5 to correct for the fact that the
    discrete statistic is being compared against a continuous distribution.
    Here, the :math:`U` statistic used is less than the mean, so we reduce
    the distance by adding 0.5 in the numerator.

    >>> import numpy as np
    >>> from scipy.stats import norm
    >>> U = min(U1, U2)
    >>> N = nx + ny
    >>> z = (U - nx*ny/2 + 0.5) / np.sqrt(nx*ny * (N + 1)/ 12)
    >>> p = 2 * norm.cdf(z)  # use CDF to get p-value from smaller statistic
    >>> print(p)
    0.11134688653314041

    If desired, we can disable the continuity correction to get a result
    that agrees with that reported in [4]_.

    >>> _, pnorm = mannwhitneyu(males, females, use_continuity=False,
    ...                         method="asymptotic")
    >>> print(pnorm)
    0.0864107329737

    Regardless of whether we perform an exact or asymptotic test, the
    probability of the test statistic being as extreme or more extreme by
    chance exceeds 5%, so we do not consider the results statistically
    significant.

    Suppose that, before seeing the data, we had hypothesized that females
    would tend to be diagnosed at a younger age than males.
    In that case, it would be natural to provide the female ages as the
    first input, and we would have performed a one-sided test using
    ``alternative = 'less'``: females are diagnosed at an age that is
    stochastically less than that of males.

    >>> res = mannwhitneyu(females, males, alternative="less", method="exact")
    >>> print(res)
    MannwhitneyuResult(statistic=3.0, pvalue=0.05555555555555555)

    Again, the probability of getting a sufficiently low value of the
    test statistic by chance under the null hypothesis is greater than 5%,
    so we do not reject the null hypothesis in favor of our alternative.

    If it is reasonable to assume that the means of samples from the
    populations are normally distributed, we could have used a t-test to
    perform the analysis.

    >>> from scipy.stats import ttest_ind
    >>> res = ttest_ind(females, males, alternative="less")
    >>> print(res)
    TtestResult(statistic=-2.239334696520584,
                pvalue=0.030068441095757924,
                df=7.0)

    Under this assumption, the *p*-value would be low enough to reject the
    null hypothesis in favor of the alternative.

    '''

    x, y, use_continuity, alternative, axis_int, method, xp = (
        _mwu_input_validation(x, y, use_continuity, alternative, axis, method))

    xy = _broadcast_concatenate((x, y), axis)

    n1, n2 = x.shape[-1], y.shape[-1]  # _axis_nan_policy decorator ensures axis=-1

    # Follows [2]
    ranks, t = _rankdata(xy, 'average', return_ties=True)  # method 2, step 1
    ranks = xp.astype(ranks, x.dtype, copy=False)
    t = xp.astype(t, x.dtype, copy=False)
    R1 = xp.sum(ranks[..., :n1], axis=-1)                  # method 2, step 2
    U1 = R1 - n1*(n1+1)/2                                  # method 2, step 3
    U2 = n1 * n2 - U1                                      # as U1 + U2 = n1 * n2

    if alternative == "greater":
        U, f = U1, 1  # U is the statistic to use for p-value, f is a factor
    elif alternative == "less":
        U, f = U2, 1  # Due to symmetry, use SF of U2 rather than CDF of U1
    else:
        U, f = xp.maximum(U1, U2), 2  # multiply SF by two for two-sided test

    if method == "auto":
        # Ties are present wherever a tie count exceeds one.
        method = _mwu_choose_method(n1, n2, xp.any(t > 1))

    if method == "exact":
        # Reuse the per-thread `_MWU` instance so that its cached arrays
        # survive across calls with the same sample sizes.
        if not hasattr(_mwu_state, 's'):
            _mwu_state.s = _MWU(0, 0)
        _mwu_state.s.set_shapes(n1, n2)
        # NOTE(review): the cast to int64 discards the fractional part of
        # half-integer `U` values, which presumably arise only when there
        # are ties (average ranks); the exact method is documented as not
        # corrected for ties.
        p = xp.asarray(_mwu_state.s.sf(np.asarray(U, np.int64)), dtype=x.dtype)
    elif method == "asymptotic":
        z = _get_mwu_z(U, n1, n2, t, continuity=use_continuity, xp=xp)
        p = special.ndtr(-z)
    else:  # `PermutationMethod` instance (already validated)
        def statistic(x, y, axis):
            return mannwhitneyu(x, y, use_continuity=use_continuity,
                                alternative=alternative, axis=axis,
                                method="asymptotic").statistic

        res = stats.permutation_test((x, y), statistic, axis=axis,
                                     **method._asdict(), alternative=alternative)
        p = res.pvalue
        # `permutation_test` is given `alternative` itself, so the two-sided
        # factor must not be applied again below.
        f = 1

    p *= f

    # Ensure that the p-value is not greater than 1
    # This could happen for exact test when U = m*n/2
    p = xp.clip(p, 0., 1.)

    return MannwhitneyuResult(U1, p)