_wilcoxon.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. import numpy as np
  2. from scipy import stats
  3. from ._stats_py import _get_pvalue, _rankdata, _SimpleNormal
  4. from . import _morestats
  5. from ._axis_nan_policy import _broadcast_arrays
  6. from ._hypotests import _get_wilcoxon_distr
  7. from scipy._lib._util import _get_nan
  8. from scipy._lib._array_api import array_namespace, xp_promote, xp_size
  9. import scipy._lib.array_api_extra as xpx
  10. class WilcoxonDistribution:
  11. def __init__(self, n):
  12. n = np.asarray(n).astype(int, copy=False)
  13. self.n = n
  14. self._dists = {ni: _get_wilcoxon_distr(ni) for ni in np.unique(n)}
  15. def _cdf1(self, k, n):
  16. pmfs = self._dists[n]
  17. return pmfs[:k + 1].sum()
  18. def _cdf(self, k, n):
  19. return np.vectorize(self._cdf1, otypes=[float])(k, n)
  20. def _sf1(self, k, n):
  21. pmfs = self._dists[n]
  22. return pmfs[k:].sum()
  23. def _sf(self, k, n):
  24. return np.vectorize(self._sf1, otypes=[float])(k, n)
  25. def mean(self):
  26. return self.n * (self.n + 1) / 4
  27. def _prep(self, k):
  28. k = np.asarray(k).astype(int, copy=False)
  29. mn = self.mean()
  30. out = np.empty(k.shape, dtype=np.float64)
  31. return k, mn, out
  32. def cdf(self, k):
  33. k, mn, out = self._prep(k)
  34. return xpx.apply_where(
  35. k <= mn, (k, self.n),
  36. self._cdf,
  37. lambda k, n: 1 - self._sf(k+1, n))[()]
  38. def sf(self, k):
  39. k, mn, out = self._prep(k)
  40. return xpx.apply_where(
  41. k <= mn, (k, self.n),
  42. self._sf,
  43. lambda k, n: 1 - self._cdf(k-1, n))[()]
  44. def _wilcoxon_iv(x, y, zero_method, correction, alternative, method, axis):
  45. xp = array_namespace(x, y)
  46. x, y = xp_promote(x, y, force_floating=True, xp=xp)
  47. axis = np.asarray(axis)[()] # OK to use NumPy for input validation
  48. message = "`axis` must be an integer."
  49. if not np.issubdtype(axis.dtype, np.integer) or axis.ndim != 0:
  50. raise ValueError(message)
  51. axis = int(axis)
  52. message = '`axis` must be compatible with the shape(s) of `x` (and `y`)'
  53. AxisError = getattr(np, 'AxisError', None) or np.exceptions.AxisError
  54. try:
  55. if y is None:
  56. d = x
  57. else:
  58. x, y = _broadcast_arrays((x, y), axis=axis, xp=xp)
  59. d = x - y
  60. d = xp.moveaxis(d, axis, -1)
  61. except AxisError as e:
  62. raise AxisError(message) from e
  63. message = "`x` and `y` must have the same length along `axis`."
  64. if y is not None and x.shape[axis] != y.shape[axis]:
  65. raise ValueError(message)
  66. message = "`x` (and `y`, if provided) must be an array of real numbers."
  67. if not xp.isdtype(d.dtype, "real floating"):
  68. raise ValueError(message)
  69. zero_method = str(zero_method).lower()
  70. zero_methods = {"wilcox", "pratt", "zsplit"}
  71. message = f"`zero_method` must be one of {zero_methods}."
  72. if zero_method not in zero_methods:
  73. raise ValueError(message)
  74. corrections = {True, False}
  75. message = f"`correction` must be one of {corrections}."
  76. if correction not in corrections:
  77. raise ValueError(message)
  78. alternative = str(alternative).lower()
  79. alternatives = {"two-sided", "less", "greater"}
  80. message = f"`alternative` must be one of {alternatives}."
  81. if alternative not in alternatives:
  82. raise ValueError(message)
  83. if not isinstance(method, stats.PermutationMethod):
  84. methods = {"auto", "asymptotic", "exact"}
  85. message = (f"`method` must be one of {methods} or "
  86. "an instance of `stats.PermutationMethod`.")
  87. if method not in methods:
  88. raise ValueError(message)
  89. output_z = True if method == 'asymptotic' else False
  90. # For small samples, we decide later whether to perform an exact test or a
  91. # permutation test. The reason is that the presence of ties is not
  92. # known at the input validation stage.
  93. n_zero = xp.count_nonzero(d == 0, axis=None)
  94. if method == "auto" and d.shape[-1] > 50:
  95. method = "asymptotic"
  96. return d, zero_method, correction, alternative, method, axis, output_z, n_zero, xp
  97. def _wilcoxon_statistic(d, method, zero_method='wilcox', *, xp):
  98. dtype = d.dtype
  99. i_zeros = (d == 0)
  100. if zero_method == 'wilcox':
  101. # Wilcoxon's method for treating zeros was to remove them from
  102. # the calculation. We do this by replacing 0s with NaNs, which
  103. # are ignored anyway.
  104. # Copy required for array-api-strict. See data-apis/array-api-extra#506.
  105. d = xpx.at(d)[i_zeros].set(xp.nan, copy=True)
  106. i_nan = xp.isnan(d)
  107. n_nan = xp.count_nonzero(i_nan, axis=-1)
  108. count = xp.astype(d.shape[-1] - n_nan, dtype)
  109. r, t = _rankdata(xp.abs(d), 'average', return_ties=True, xp=xp)
  110. r, t = xp.astype(r, dtype, copy=False), xp.astype(t, dtype, copy=False)
  111. r_plus = xp.sum(xp.astype(d > 0, dtype) * r, axis=-1)
  112. r_minus = xp.sum(xp.astype(d < 0, dtype) * r, axis=-1)
  113. has_ties = xp.any(t == 0)
  114. if zero_method == "zsplit":
  115. # The "zero-split" method for treating zeros is to add half their contribution
  116. # to r_plus and half to r_minus.
  117. # See gh-2263 for the origin of this method.
  118. r_zero_2 = xp.sum(xp.astype(i_zeros, dtype) * r, axis=-1) / 2
  119. r_plus = xpx.at(r_plus)[...].add(r_zero_2)
  120. r_minus = xpx.at(r_minus)[...].add(r_zero_2)
  121. mn = count * (count + 1.) * 0.25
  122. se = count * (count + 1.) * (2. * count + 1.)
  123. if zero_method == "pratt":
  124. # Pratt's method for treating zeros was just to modify the z-statistic.
  125. # normal approximation needs to be adjusted, see Cureton (1967)
  126. n_zero = xp.astype(xp.count_nonzero(i_zeros, axis=-1), dtype)
  127. mn = xpx.at(mn)[...].subtract(n_zero * (n_zero + 1.) * 0.25)
  128. se = xpx.at(se)[...].subtract(n_zero * (n_zero + 1.) * (2. * n_zero + 1.))
  129. # zeros are not to be included in tie-correction.
  130. # any tie counts corresponding with zeros are in the 0th column
  131. # t[xp.any(i_zeros, axis=-1), 0] = 0
  132. t_i_zeros = xp.zeros_like(i_zeros)
  133. t_i_zeros = xpx.at(t_i_zeros)[..., 0].set(xp.any(i_zeros, axis=-1))
  134. t = xpx.at(t)[t_i_zeros].set(0.)
  135. tie_correct = xp.sum(t**3 - t, axis=-1)
  136. se = xp.sqrt((se - tie_correct/2) / 24)
  137. # se = 0 means that no non-zero values are left in d. we only need z
  138. # if method is asymptotic. however, if method="auto", the switch to
  139. # asymptotic might only happen after the statistic is calculated, so z
  140. # needs to be computed. in all other cases, avoid division by zero warning
  141. # (z is not needed anyways)
  142. if method in ["asymptotic", "auto"]:
  143. z = (r_plus - mn) / se
  144. else:
  145. z = xp.nan
  146. return r_plus, r_minus, se, z, count, has_ties
  def _correction_sign(z, alternative, xp):
      """Sign to apply to the continuity correction of the z-statistic `z`.

      One-sided alternatives have a fixed direction: +1 for ``'greater'`` and
      -1 for ``'less'``. Any other alternative follows the sign of `z`
      elementwise, so ``z - sign * c`` moves `z` toward zero.
      """
      fixed_signs = {'greater': 1, 'less': -1}
      if alternative in fixed_signs:
          return fixed_signs[alternative]
      return xp.sign(z)
  def _resolve_auto_method(has_ties, n_zero, n):
      """Return the concrete method that ``method="auto"`` stands for."""
      # Ties are only known after the statistic is computed, which is why this
      # is decided here rather than during input validation.
      if not (has_ties or n_zero > 0):
          # `method="exact"` is only exact without ties or zeros.
          return "exact"
      if n <= 13:
          # The permutation test has 2**n possible outcomes; for n <= 13 that is
          # at most 2**13 = 8192 < 9999 (the default `n_resamples`), so the
          # p-value is deterministic.
          return stats.PermutationMethod()
      # Ties and too large a sample for a deterministic permutation test.
      return "asymptotic"


  def _exact_pvalue(r_plus, count, alternative, dtype, xp):
      """P-value of `r_plus` under the exact null distribution for `count` ranks."""
      dist = WilcoxonDistribution(count)
      # The null distribution in `dist` is exact only if there are no ties
      # or zeros. If there are ties or zeros, the statistic can be non-
      # integral, but the null distribution is only defined for integral
      # values of the statistic. Therefore, we're conservative: round
      # non-integral statistic up before computing CDF and down before
      # computing SF. This preserves symmetry w.r.t. alternatives and
      # order of the input arguments. See gh-19872.
      r_plus_np = np.asarray(r_plus)
      if alternative == 'less':
          p = dist.cdf(np.ceil(r_plus_np))
      elif alternative == 'greater':
          p = dist.sf(np.floor(r_plus_np))
      else:
          p = 2 * np.minimum(dist.sf(np.floor(r_plus_np)),
                             dist.cdf(np.ceil(r_plus_np)))
      # Clip for every alternative: doubling can exceed 1, and the `1 - ...`
      # complements used inside `dist` may round slightly outside [0, 1].
      p = np.clip(p, 0, 1)
      return xp.asarray(p, dtype=dtype)


  def _wilcoxon_nd(x, y=None, zero_method='wilcox', correction=True,
                   alternative='two-sided', method='auto', axis=0):
      """Wilcoxon signed-rank test of `x` (or of ``x - y``) along `axis`.

      Returns
      -------
      res : _morestats.WilcoxonResult
          `statistic` is ``min(r_plus, r_minus)`` for the two-sided test and
          `r_plus` otherwise; `pvalue` is the p-value. A `zstatistic` attribute
          is attached only when ``method="asymptotic"`` was passed explicitly.
      """
      temp = _wilcoxon_iv(x, y, zero_method, correction, alternative, method, axis)
      (d, zero_method, correction, alternative, method, axis, output_z,
       n_zero, xp) = temp

      if xp_size(d) == 0:
          NaN = _get_nan(d, xp=xp)
          res = _morestats.WilcoxonResult(statistic=NaN, pvalue=NaN)
          # Same criterion as the main path below. `method` may already have
          # been switched from "auto" to "asymptotic" for a long (empty) axis.
          if output_z:
              res.zstatistic = NaN
          return res

      r_plus, r_minus, se, z, count, has_ties = _wilcoxon_statistic(
          d, method, zero_method, xp=xp
      )

      if method == "auto":
          method = _resolve_auto_method(has_ties, n_zero, d.shape[-1])

      if method == 'asymptotic':
          if correction:
              sign = _correction_sign(z, alternative, xp=xp)
              z = xpx.at(z)[...].subtract(sign * 0.5 / se)
          p = _get_pvalue(z, _SimpleNormal(), alternative, xp=xp)
      elif method == 'exact':
          p = _exact_pvalue(r_plus, count, alternative, d.dtype, xp)
      else:  # `PermutationMethod` instance (already validated)
          p = stats.permutation_test(
              (d,), lambda d: _wilcoxon_statistic(d, method, zero_method, xp=xp)[0],
              permutation_type='samples', **method._asdict(),
              alternative=alternative, axis=-1).pvalue

      # for backward compatibility...
      two_sided = alternative == 'two-sided'
      statistic = xp.minimum(r_plus, r_minus) if two_sided else r_plus
      z = -xp.abs(z) if (two_sided and method == 'asymptotic') else z

      # Return 0-d results as scalars.
      statistic = statistic[()] if statistic.ndim == 0 else statistic
      p = p[()] if p.ndim == 0 else p
      res = _morestats.WilcoxonResult(statistic=statistic, pvalue=p)
      if output_z:
          res.zstatistic = z[()] if z.ndim == 0 else z
      return res