_axis_nan_policy.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716
  1. # Many scipy.stats functions support `axis` and `nan_policy` parameters.
  2. # When the two are combined, it can be tricky to get all the behavior just
  3. # right. This file contains utility functions useful for scipy.stats functions
  4. # that support `axis` and `nan_policy`, including a decorator that
  5. # automatically adds `axis` and `nan_policy` arguments to a function.
  6. import math
  7. import warnings
  8. import numpy as np
  9. from functools import wraps
  10. from scipy._lib._array_api import xp_ravel
  11. from scipy._lib._docscrape import FunctionDoc, Parameter
  12. from scipy._lib._util import _contains_nan, AxisError, _get_nan
  13. from scipy._lib._array_api import (array_namespace, is_numpy, xp_size, xp_copy,
  14. xp_promote, is_lazy_array)
  15. import scipy._lib.array_api_extra as xpx
  16. import inspect
  17. too_small_1d_not_omit = (
  18. "One or more sample arguments is too small; all "
  19. "returned values will be NaN. "
  20. "See documentation for sample size requirements.")
  21. too_small_1d_omit = (
  22. "After omitting NaNs, one or more sample arguments "
  23. "is too small; all returned values will be NaN. "
  24. "See documentation for sample size requirements.")
  25. too_small_nd_not_omit = (
  26. "All axis-slices of one or more sample arguments are "
  27. "too small; all elements of returned arrays will be NaN. "
  28. "See documentation for sample size requirements.")
  29. too_small_nd_omit = (
  30. "After omitting NaNs, one or more axis-slices of one "
  31. "or more sample arguments is too small; corresponding "
  32. "elements of returned arrays will be NaN. "
  33. "See documentation for sample size requirements.")
  34. class SmallSampleWarning(RuntimeWarning):
  35. pass
  36. def _broadcast_arrays(arrays, axis=None, xp=None):
  37. """
  38. Broadcast shapes of arrays, ignoring incompatibility of specified axes
  39. """
  40. arrays = tuple(arrays)
  41. if not arrays:
  42. return arrays
  43. xp = array_namespace(*arrays) if xp is None else xp
  44. arrays = [xp.asarray(arr) for arr in arrays]
  45. shapes = [arr.shape for arr in arrays]
  46. new_shapes = _broadcast_shapes(shapes, axis)
  47. if axis is None:
  48. new_shapes = [new_shapes]*len(arrays)
  49. return [xp.broadcast_to(array, new_shape)
  50. for array, new_shape in zip(arrays, new_shapes)]
  51. def _broadcast_shapes(shapes, axis=None):
  52. """
  53. Broadcast shapes, ignoring incompatibility of specified axes
  54. """
  55. if not shapes:
  56. return shapes
  57. # input validation
  58. if axis is not None:
  59. axis = np.atleast_1d(axis)
  60. message = '`axis` must be an integer, a tuple of integers, or `None`.'
  61. try:
  62. with np.errstate(invalid='ignore'):
  63. axis_int = axis.astype(int)
  64. except ValueError as e:
  65. raise AxisError(message) from e
  66. if not np.array_equal(axis_int, axis):
  67. raise AxisError(message)
  68. axis = axis_int
  69. # First, ensure all shapes have same number of dimensions by prepending 1s.
  70. n_dims = max([len(shape) for shape in shapes])
  71. new_shapes = np.ones((len(shapes), n_dims), dtype=int)
  72. for row, shape in zip(new_shapes, shapes):
  73. row[len(row)-len(shape):] = shape # can't use negative indices (-0:)
  74. # Remove the shape elements of the axes to be ignored, but remember them.
  75. if axis is not None:
  76. axis[axis < 0] = n_dims + axis[axis < 0]
  77. axis = np.sort(axis)
  78. if axis[-1] >= n_dims or axis[0] < 0:
  79. message = (f"`axis` is out of bounds "
  80. f"for array of dimension {n_dims}")
  81. raise AxisError(message)
  82. if len(np.unique(axis)) != len(axis):
  83. raise AxisError("`axis` must contain only distinct elements")
  84. removed_shapes = new_shapes[:, axis]
  85. new_shapes = np.delete(new_shapes, axis, axis=1)
  86. # If arrays are broadcastable, shape elements that are 1 may be replaced
  87. # with a corresponding non-1 shape element. Assuming arrays are
  88. # broadcastable, that final shape element can be found with:
  89. new_shape = np.max(new_shapes, axis=0)
  90. # except in case of an empty array:
  91. new_shape *= new_shapes.all(axis=0)
  92. # Among all arrays, there can only be one unique non-1 shape element.
  93. # Therefore, if any non-1 shape element does not match what we found
  94. # above, the arrays must not be broadcastable after all.
  95. if np.any(~((new_shapes == 1) | (new_shapes == new_shape))):
  96. raise ValueError("Array shapes are incompatible for broadcasting.")
  97. if axis is not None:
  98. # Add back the shape elements that were ignored
  99. new_axis = axis - np.arange(len(axis))
  100. new_shapes = [tuple(np.insert(new_shape, new_axis, removed_shape))
  101. for removed_shape in removed_shapes]
  102. return new_shapes
  103. else:
  104. return tuple(new_shape)
  105. def _broadcast_array_shapes_remove_axis(arrays, axis=None):
  106. """
  107. Broadcast shapes of arrays, dropping specified axes
  108. Given a sequence of arrays `arrays` and an integer or tuple `axis`, find
  109. the shape of the broadcast result after consuming/dropping `axis`.
  110. In other words, return output shape of a typical hypothesis test on
  111. `arrays` vectorized along `axis`.
  112. Examples
  113. --------
  114. >>> import numpy as np
  115. >>> from scipy.stats._axis_nan_policy import _broadcast_array_shapes_remove_axis
  116. >>> a = np.zeros((5, 2, 1))
  117. >>> b = np.zeros((9, 3))
  118. >>> _broadcast_array_shapes_remove_axis((a, b), 1)
  119. (5, 3)
  120. """
  121. # Note that here, `axis=None` means do not consume/drop any axes - _not_
  122. # ravel arrays before broadcasting.
  123. shapes = [arr.shape for arr in arrays]
  124. return _broadcast_shapes_remove_axis(shapes, axis)
  125. def _broadcast_shapes_remove_axis(shapes, axis=None):
  126. """
  127. Broadcast shapes, dropping specified axes
  128. Same as _broadcast_array_shapes_remove_axis, but given a sequence
  129. of array shapes `shapes` instead of the arrays themselves.
  130. """
  131. shapes = _broadcast_shapes(shapes, axis)
  132. shape = shapes[0]
  133. if axis is not None:
  134. shape = np.delete(shape, axis)
  135. return tuple(shape)
  136. def _broadcast_concatenate(arrays, axis, paired=False, xp=None):
  137. """Concatenate arrays along an axis with broadcasting."""
  138. xp = array_namespace(*arrays) if xp is None else xp
  139. arrays = _broadcast_arrays(arrays, axis if not paired else None, xp=xp)
  140. res = xp.concat(arrays, axis=axis)
  141. return res
  142. def _remove_nans(samples, paired, xp=None):
  143. "Remove nans from paired or unpaired 1D samples"
  144. # potential optimization: don't copy arrays that don't contain nans
  145. xp = array_namespace(*samples)
  146. if not paired:
  147. return [sample[~xp.isnan(sample)] for sample in samples]
  148. # for paired samples, we need to remove the whole pair when any part
  149. # has a nan
  150. nans = xp.isnan(samples[0])
  151. for sample in samples[1:]:
  152. nans = nans | xp.isnan(sample)
  153. not_nans = ~nans
  154. return [sample[not_nans] for sample in samples]
  155. def _remove_sentinel(samples, paired, sentinel):
  156. "Remove sentinel values from paired or unpaired 1D samples"
  157. # could consolidate with `_remove_nans`, but it's not quite as simple as
  158. # passing `sentinel=np.nan` because `(np.nan == np.nan) is False`
  159. # potential optimization: don't copy arrays that don't contain sentinel
  160. if not paired:
  161. return [sample[sample != sentinel] for sample in samples]
  162. # for paired samples, we need to remove the whole pair when any part
  163. # has a nan
  164. sentinels = (samples[0] == sentinel)
  165. for sample in samples[1:]:
  166. sentinels = sentinels | (sample == sentinel)
  167. not_sentinels = ~sentinels
  168. return [sample[not_sentinels] for sample in samples]
  169. def _masked_arrays_2_sentinel_arrays(samples):
  170. # masked arrays in `samples` are converted to regular arrays, and values
  171. # corresponding with masked elements are replaced with a sentinel value
  172. # return without modifying arrays if none have a mask
  173. has_mask = False
  174. for sample in samples:
  175. mask = getattr(sample, 'mask', False)
  176. has_mask = has_mask or np.any(mask)
  177. if not has_mask:
  178. return samples, None # None means there is no sentinel value
  179. # Choose a sentinel value. We can't use `np.nan`, because sentinel (masked)
  180. # values are always omitted, but there are different nan policies.
  181. dtype = np.result_type(*samples)
  182. dtype = dtype if np.issubdtype(dtype, np.number) else np.float64
  183. for i in range(len(samples)):
  184. # Things get more complicated if the arrays are of different types.
  185. # We could have different sentinel values for each array, but
  186. # the purpose of this code is convenience, not efficiency.
  187. samples[i] = samples[i].astype(dtype, copy=False)
  188. inexact = np.issubdtype(dtype, np.inexact)
  189. info = np.finfo if inexact else np.iinfo
  190. max_possible, min_possible = info(dtype).max, info(dtype).min
  191. nextafter = np.nextafter if inexact else (lambda x, _: x - 1)
  192. sentinel = max_possible
  193. # For simplicity, min_possible/np.infs are not candidate sentinel values
  194. while sentinel > min_possible:
  195. for sample in samples:
  196. if np.any(sample == sentinel): # choose a new sentinel value
  197. sentinel = nextafter(sentinel, -np.inf)
  198. break
  199. else: # when sentinel value is OK, break the while loop
  200. break
  201. else:
  202. message = ("This function replaces masked elements with sentinel "
  203. "values, but the data contains all distinct values of this "
  204. "data type. Consider promoting the dtype to `np.float64`.")
  205. raise ValueError(message)
  206. # replace masked elements with sentinel value
  207. out_samples = []
  208. for sample in samples:
  209. mask = getattr(sample, 'mask', None)
  210. if mask is not None: # turn all masked arrays into sentinel arrays
  211. mask = np.broadcast_to(mask, sample.shape)
  212. sample = sample.data.copy() if np.any(mask) else sample.data
  213. sample = np.asarray(sample) # `sample.data` could be a memoryview?
  214. sample[mask] = sentinel
  215. out_samples.append(sample)
  216. return out_samples, sentinel
  217. def _check_empty_inputs(samples, axis, xp=None):
  218. """
  219. Check for empty sample; return appropriate output for a vectorized hypotest
  220. """
  221. xp = array_namespace(*samples) if xp is None else xp
  222. # if none of the samples are empty, we need to perform the test
  223. if not any(xp_size(sample) == 0 for sample in samples):
  224. return None
  225. # otherwise, the statistic and p-value will be either empty arrays or
  226. # arrays with NaNs. Produce the appropriate array and return it.
  227. output_shape = _broadcast_array_shapes_remove_axis(samples, axis)
  228. NaN = _get_nan(*samples)
  229. output = xp.full(output_shape, xp.nan, dtype=NaN.dtype)
  230. return output
  231. def _add_reduced_axes(res, reduced_axes, keepdims, xp=np):
  232. """
  233. Add reduced axes back to all the arrays in the result object
  234. if keepdims = True.
  235. """
  236. return ([xpx.expand_dims(output, axis=reduced_axes)
  237. if not isinstance(output, int) else output for output in res]
  238. if keepdims else res)
  239. # Standard docstring / signature entries for `axis`, `nan_policy`, `keepdims`
  240. _name = 'axis'
  241. _desc = (
  242. """If an int, the axis of the input along which to compute the statistic.
  243. The statistic of each axis-slice (e.g. row) of the input will appear in a
  244. corresponding element of the output.
  245. If ``None``, the input will be raveled before computing the statistic."""
  246. .split('\n'))
  247. def _get_axis_params(default_axis=0, _name=_name, _desc=_desc): # bind NOW
  248. _type = f"int or None, default: {default_axis}"
  249. _axis_parameter_doc = Parameter(_name, _type, _desc)
  250. _axis_parameter = inspect.Parameter(_name,
  251. inspect.Parameter.KEYWORD_ONLY,
  252. default=default_axis)
  253. return _axis_parameter_doc, _axis_parameter
  254. _name = 'nan_policy'
  255. _type = "{'propagate', 'omit', 'raise'}"
  256. _desc = (
  257. """Defines how to handle input NaNs.
  258. - ``propagate``: if a NaN is present in the axis slice (e.g. row) along
  259. which the statistic is computed, the corresponding entry of the output
  260. will be NaN.
  261. - ``omit``: NaNs will be omitted when performing the calculation.
  262. If insufficient data remains in the axis slice along which the
  263. statistic is computed, the corresponding entry of the output will be
  264. NaN.
  265. - ``raise``: if a NaN is present, a ``ValueError`` will be raised."""
  266. .split('\n'))
  267. _nan_policy_parameter_doc = Parameter(_name, _type, _desc)
  268. _nan_policy_parameter = inspect.Parameter(_name,
  269. inspect.Parameter.KEYWORD_ONLY,
  270. default='propagate')
  271. _name = 'keepdims'
  272. _type = "bool, default: False"
  273. _desc = (
  274. """If this is set to True, the axes which are reduced are left
  275. in the result as dimensions with size one. With this option,
  276. the result will broadcast correctly against the input array."""
  277. .split('\n'))
  278. _keepdims_parameter_doc = Parameter(_name, _type, _desc)
  279. _keepdims_parameter = inspect.Parameter(_name,
  280. inspect.Parameter.KEYWORD_ONLY,
  281. default=False)
  282. _standard_note_addition = (
  283. """\nBeginning in SciPy 1.9, ``np.matrix`` inputs (not recommended for new
  284. code) are converted to ``np.ndarray`` before the calculation is performed. In
  285. this case, the output will be a scalar or ``np.ndarray`` of appropriate shape
  286. rather than a 2D ``np.matrix``. Similarly, while masked elements of masked
  287. arrays are ignored, the output will be a scalar or ``np.ndarray`` rather than a
  288. masked array with ``mask=False``.""").split('\n')
def _axis_nan_policy_factory(tuple_to_result, default_axis=0,
                             n_samples=1, paired=False,
                             result_to_tuple=None, too_small=0,
                             n_outputs=2, kwd_samples=(), override=None):
    """Factory for a wrapper that adds axis/nan_policy params to a function.

    Parameters
    ----------
    tuple_to_result : callable
        Callable that returns an object of the type returned by the function
        being wrapped (e.g. the namedtuple or dataclass returned by a
        statistical test) provided the separate components (e.g. statistic,
        pvalue).
    default_axis : int, default: 0
        The default value of the axis argument. Standard is 0 except when
        backwards compatibility demands otherwise (e.g. `None`).
    n_samples : int or callable, default: 1
        The number of data samples accepted by the function
        (e.g. `mannwhitneyu`), a callable that accepts a dictionary of
        parameters passed into the function and returns the number of data
        samples (e.g. `wilcoxon`), or `None` to indicate an arbitrary number
        of samples (e.g. `kruskal`).
    paired : {False, True}
        Whether the function being wrapped treats the samples as paired (i.e.
        corresponding elements of each sample should be considered as different
        components of the same sample.)
    result_to_tuple : callable, optional
        Function that unpacks the results of the function being wrapped into
        a tuple. This is essentially the inverse of `tuple_to_result`. Default
        is `None`, which is appropriate for statistical tests that return a
        statistic, pvalue tuple (rather than, e.g., a non-iterable dataclass).
    too_small : int or callable, default: 0
        The largest unacceptably small sample for the function being wrapped.
        For example, some functions require samples of size two or more or they
        raise an error. This argument prevents the error from being raised when
        input is not 1D and instead places a NaN in the corresponding element
        of the result. If callable, it must accept a list of samples, axis,
        and a dictionary of keyword arguments passed to the wrapper function as
        arguments and return a bool indicating whether the samples passed are
        too small.
    n_outputs : int or callable, default: 2
        The number of outputs produced by the function given 1d sample(s). For
        example, hypothesis tests that return a namedtuple or result object
        with attributes ``statistic`` and ``pvalue`` use the default
        ``n_outputs=2``; summary statistics with scalar output use
        ``n_outputs=1``. Alternatively, may be a callable that accepts a
        dictionary of arguments passed into the wrapped function and returns
        the number of outputs corresponding with those arguments.
    kwd_samples : sequence, default: ()
        The names of keyword parameters that should be treated as samples. For
        example, `gmean` accepts as its first argument a sample `a` but
        also `weights` as a fourth, optional keyword argument. In this case, we
        use `n_samples=1` and kwd_samples=['weights'].
    override : dict, default: {'vectorization': False, 'nan_propagation': True}
        Pass a dictionary with ``'vectorization': True`` to ensure that the
        decorator overrides the function's behavior for multidimensional input.
        Use ``'nan_propagation': False`` to ensure that the decorator does not
        override the function's behavior for ``nan_policy='propagate'``.

    Returns
    -------
    callable
        A decorator. Applied to a function, it returns a wrapper whose
        docstring and signature additionally advertise `axis`, `nan_policy`
        and `keepdims`.
    """
    # Specify which existing behaviors the decorator must override
    # (user-provided entries take precedence over the defaults below)
    temp = override or {}
    override = {'vectorization': False,
                'nan_propagation': True}
    override.update(temp)

    if result_to_tuple is None:
        # By default, the result is assumed to already be tuple-like
        def result_to_tuple(res, _):
            return res

    if not callable(too_small):
        # Integer threshold: a sample is too small if its length along `axis`
        # does not exceed `too_small`. Extra arguments are accepted (and
        # ignored) so that the call signature matches a user-provided callable.
        def is_too_small(samples, *ts_args, axis=-1, **ts_kwargs):
            for sample in samples:
                if sample.shape[axis] <= too_small:
                    return True
            return False
    else:
        is_too_small = too_small

    def axis_nan_policy_decorator(hypotest_fun_in):
        @wraps(hypotest_fun_in)
        def axis_nan_policy_wrapper(*args, _no_deco=False, **kwds):

            if _no_deco:  # for testing, decorator does nothing
                return hypotest_fun_in(*args, **kwds)

            # Lazy arrays bypass the rest of the decorator: the wrapped
            # function is called directly, and explicitly passing `nan_policy`
            # or `keepdims` raises. In the future, we'll probably want to use
            # the decorator for `keepdims`, `axis` tuples, etc.
            # NOTE(review): with no positional arguments, `temp` is a list of
            # keyword *names* (strings) rather than an array, so the lazy-array
            # check below presumably never triggers in that case - confirm
            # whether the sample itself was intended.
            if len(args) == 0:  # extract sample from `kwds` if there are no `args`
                used_kwd_samples = list(set(kwds).intersection(set(kwd_samples)))
                temp = used_kwd_samples[:1]
            else:
                temp = args[0]

            if is_lazy_array(temp):
                msg = ("Use of `nan_policy` and `keepdims` "
                       "is incompatible with lazy arrays.")
                if 'nan_policy' in kwds or 'keepdims' in kwds:
                    raise NotImplementedError(msg)
                return hypotest_fun_in(*args, **kwds)

            # We need to be flexible about whether position or keyword
            # arguments are used, but we need to make sure users don't pass
            # both for the same parameter. To complicate matters, some
            # functions accept samples with *args, and some functions already
            # accept `axis` and `nan_policy` as positional arguments.
            # The strategy is to make sure that there is no duplication
            # between `args` and `kwds`, combine the two into `kwds`, then
            # pop the samples, `nan_policy`, and `axis` from `kwds`, as they
            # are dealt with separately.

            # Check for intersection between positional and keyword args
            params = list(inspect.signature(hypotest_fun_in).parameters)
            if n_samples is None:
                # Give unique names to each positional sample argument
                # Note that *args can't be provided as a keyword argument
                params = [f"arg{i}" for i in range(len(args))] + params[1:]

            # raise if there are too many positional args
            maxarg = (np.inf if inspect.getfullargspec(hypotest_fun_in).varargs
                      else len(inspect.getfullargspec(hypotest_fun_in).args))
            if len(args) > maxarg:  # let the function raise the right error
                hypotest_fun_in(*args, **kwds)

            # raise if multiple values passed for same parameter
            d_args = dict(zip(params, args))
            intersection = set(d_args) & set(kwds)
            if intersection:  # let the function raise the right error
                hypotest_fun_in(*args, **kwds)

            # Consolidate other positional and keyword args into `kwds`
            kwds.update(d_args)

            # rename avoids UnboundLocalError
            if callable(n_samples):
                # Future refactoring idea: no need for callable n_samples.
                # Just replace `n_samples` and `kwd_samples` with a single
                # list of the names of all samples, and treat all of them
                # as `kwd_samples` are treated below.
                n_samp = n_samples(kwds)
            else:
                # `n_samples=None` means every positional argument is a sample
                n_samp = n_samples or len(args)

            # get the number of outputs
            n_out = n_outputs  # rename to avoid UnboundLocalError
            if callable(n_out):
                n_out = n_out(kwds)

            # If necessary, rearrange function signature: accept other samples
            # as positional args right after the first n_samp args
            kwd_samp = [name for name in kwd_samples
                        if kwds.get(name, None) is not None]
            n_kwd_samp = len(kwd_samp)
            if not kwd_samp:
                hypotest_fun_out = hypotest_fun_in
            else:
                # Route the trailing positional samples back to their keywords
                def hypotest_fun_out(*samples, **kwds):
                    new_kwds = dict(zip(kwd_samp, samples[n_samp:]))
                    kwds.update(new_kwds)
                    return hypotest_fun_in(*samples[:n_samp], **kwds)

            # Extract the things we need here
            try:  # if something is missing
                samples = [kwds.pop(param) for param in (params[:n_samp] + kwd_samp)]
                xp = array_namespace(*samples)
                samples = xp_promote(*samples, xp=xp)
                samples = (samples,) if not isinstance(samples, tuple) else samples
                samples = [xpx.atleast_nd(sample, ndim=1) for sample in samples]
            except KeyError:  # let the function raise the right error
                # might need to revisit this if required arg is not a "sample"
                hypotest_fun_in(*args, **kwds)
            # The wrapped function is treated as natively vectorized if it
            # has an `axis` parameter, unless vectorization is overridden.
            vectorized = True if 'axis' in params else False
            vectorized = vectorized and not override['vectorization']
            axis = kwds.pop('axis', default_axis)
            nan_policy = kwds.pop('nan_policy', 'propagate')
            keepdims = kwds.pop("keepdims", False)
            del args  # avoid the possibility of passing both `args` and `kwds`

            # convert masked arrays to regular arrays with sentinel values
            # NOTE(review): `sentinel` is tested by truthiness (`if sentinel:`)
            # throughout, so a sentinel equal to 0 would be treated like
            # "no sentinel".
            sentinel = None
            if is_numpy(xp):
                samples, sentinel = _masked_arrays_2_sentinel_arrays(samples)

            # standardize to always work along last axis
            # (`reduced_axes` remembers the original axes for `keepdims`)
            reduced_axes = axis
            if axis is None:
                if samples:
                    # when axis=None, take the maximum of all dimensions since
                    # all the dimensions are reduced.
                    n_dims = max([xp.asarray(sample).ndim for sample in samples])
                    reduced_axes = tuple(range(n_dims))
                samples = [xp_ravel(sample) for sample in samples]
            else:
                # don't ignore any axes when broadcasting if paired
                samples = _broadcast_arrays(samples, axis=axis if not paired else None)
                axis = (axis,) if np.isscalar(axis) else axis

                n_axes = len(axis)
                # move all axes in `axis` to the end to be raveled
                samples = [xp.moveaxis(sample, axis, tuple(range(-len(axis), 0)))
                           for sample in samples]
                shapes = [sample.shape for sample in samples]
                # New shape is unchanged for all axes _not_ in `axis`
                # At the end, we append the product of the shapes of the axes
                # in `axis`. Appending -1 doesn't work for zero-size arrays!
                new_shapes = [shape[:-n_axes] + (math.prod(shape[-n_axes:]),)
                              for shape in shapes]
                samples = [xp.reshape(sample, new_shape)
                           for sample, new_shape in zip(samples, new_shapes)]
            axis = -1  # work over the last axis
            NaN = _get_nan(*samples) if samples else xp.nan

            # if axis is not needed, just handle nan_policy and return
            ndims = np.array([sample.ndim for sample in samples])  # NumPy OK for ndims
            if np.all(ndims <= 1):
                # Addresses nan_policy == "raise"
                if nan_policy != 'propagate' or override['nan_propagation']:
                    contains_nan = [_contains_nan(sample, nan_policy)
                                    for sample in samples]
                else:
                    # Behave as though there are no NaNs (even if there are)
                    contains_nan = [False] * len(samples)

                # Addresses nan_policy == "propagate"
                if any(contains_nan) and (nan_policy == 'propagate'
                                          and override['nan_propagation']):
                    res = xp.full(n_out, xp.nan, dtype=NaN.dtype)
                    res = _add_reduced_axes(res, reduced_axes, keepdims)
                    return tuple_to_result(*res)

                # Addresses nan_policy == "omit"
                too_small_msg = too_small_1d_not_omit
                if any(contains_nan) and nan_policy == 'omit':
                    # consider passing in contains_nan
                    samples = _remove_nans(samples, paired)
                    too_small_msg = too_small_1d_omit

                if sentinel:
                    samples = _remove_sentinel(samples, paired, sentinel)

                if is_too_small(samples, kwds):
                    warnings.warn(too_small_msg, SmallSampleWarning, stacklevel=2)
                    res = xp.full(n_out, xp.nan, dtype=NaN.dtype)
                    res = _add_reduced_axes(res, reduced_axes, keepdims)
                    return tuple_to_result(*res)

                res = hypotest_fun_out(*samples, **kwds)
                res = result_to_tuple(res, n_out)
                res = _add_reduced_axes(res, reduced_axes, keepdims)
                return tuple_to_result(*res)

            # check for empty input
            empty_output = _check_empty_inputs(samples, axis, xp=xp)
            # only return empty output if zero sized input is too small.
            if (
                empty_output is not None
                and (is_too_small(samples, kwds) or xp_size(empty_output) == 0)
            ):
                # Warn only when NaNs (rather than an empty array) are returned
                if is_too_small(samples, kwds) and xp_size(empty_output) != 0:
                    warnings.warn(too_small_nd_not_omit, SmallSampleWarning,
                                  stacklevel=2)
                res = [xp_copy(empty_output) for i in range(n_out)]
                res = _add_reduced_axes(res, reduced_axes, keepdims)
                return tuple_to_result(*res)

            # NOTE(review): `nan_policy` was already popped from `kwds` above,
            # so `'nan_policy' in kwds` looks like it can never be true at this
            # point - confirm whether this check was meant to run before the
            # pop.
            if not is_numpy(xp) and 'nan_policy' in kwds:
                msg = ("Use of `nan_policy` is incompatible with multidimensional "
                       "non-NumPy arrays.")
                raise NotImplementedError(msg)

            # Non-NumPy arrays: pass `axis` through to the wrapped function
            if not is_numpy(xp):
                res = hypotest_fun_out(*samples, axis=axis, **kwds)
                res = result_to_tuple(res, n_out)
                res = _add_reduced_axes(res, reduced_axes, keepdims, xp=xp)
                return tuple_to_result(*res)

            # otherwise, concatenate all samples along axis, remembering where
            # each separate sample begins
            lengths = np.array([sample.shape[axis] for sample in samples])
            split_indices = np.cumsum(lengths)
            x = _broadcast_concatenate(samples, axis, paired=paired)

            # Addresses nan_policy == "raise"
            if nan_policy != 'propagate' or override['nan_propagation']:
                contains_nan = _contains_nan(x, nan_policy)
            else:
                contains_nan = False  # behave like there are no NaNs

            # Fast path: let the wrapped function do its own vectorization
            if vectorized and not contains_nan and not sentinel:
                res = hypotest_fun_out(*samples, axis=axis, **kwds)
                res = result_to_tuple(res, n_out)
                res = _add_reduced_axes(res, reduced_axes, keepdims)
                return tuple_to_result(*res)

            # Each `hypotest_fun` below receives one 1-D slice of the
            # concatenated array `x` and splits it back into the samples.

            # Addresses nan_policy == "omit"
            if contains_nan and nan_policy == 'omit':
                def hypotest_fun(x):
                    samples = np.split(x, split_indices)[:n_samp+n_kwd_samp]
                    samples = _remove_nans(samples, paired)
                    if sentinel:
                        samples = _remove_sentinel(samples, paired, sentinel)
                    if is_too_small(samples, kwds):
                        warnings.warn(too_small_nd_omit, SmallSampleWarning,
                                      stacklevel=4)
                        return np.full(n_out, NaN)
                    return result_to_tuple(hypotest_fun_out(*samples, **kwds), n_out)

            # Addresses nan_policy == "propagate"
            elif (contains_nan and nan_policy == 'propagate'
                  and override['nan_propagation']):
                def hypotest_fun(x):
                    if np.isnan(x).any():
                        return np.full(n_out, NaN)

                    samples = np.split(x, split_indices)[:n_samp+n_kwd_samp]
                    if sentinel:
                        samples = _remove_sentinel(samples, paired, sentinel)
                    if is_too_small(samples, kwds):
                        return np.full(n_out, NaN)
                    return result_to_tuple(hypotest_fun_out(*samples, **kwds), n_out)

            else:
                def hypotest_fun(x):
                    samples = np.split(x, split_indices)[:n_samp+n_kwd_samp]
                    if sentinel:
                        samples = _remove_sentinel(samples, paired, sentinel)
                    if is_too_small(samples, kwds):
                        return np.full(n_out, NaN)
                    return result_to_tuple(hypotest_fun_out(*samples, **kwds), n_out)

            # Apply the 1-D function to every slice; outputs end up on axis 0
            x = np.moveaxis(x, axis, 0)
            res = np.apply_along_axis(hypotest_fun, axis=0, arr=x)
            res = _add_reduced_axes(res, reduced_axes, keepdims)
            return tuple_to_result(*res)

        # Replace or append the standard `axis`, `nan_policy` and `keepdims`
        # entries in the wrapper's docstring
        _axis_parameter_doc, _axis_parameter = _get_axis_params(default_axis)
        doc = FunctionDoc(axis_nan_policy_wrapper)
        parameter_names = [param.name for param in doc['Parameters']]
        if 'axis' in parameter_names:
            doc['Parameters'][parameter_names.index('axis')] = (
                _axis_parameter_doc)
        else:
            doc['Parameters'].append(_axis_parameter_doc)
        if 'nan_policy' in parameter_names:
            doc['Parameters'][parameter_names.index('nan_policy')] = (
                _nan_policy_parameter_doc)
        else:
            doc['Parameters'].append(_nan_policy_parameter_doc)
        if 'keepdims' in parameter_names:
            doc['Parameters'][parameter_names.index('keepdims')] = (
                _keepdims_parameter_doc)
        else:
            doc['Parameters'].append(_keepdims_parameter_doc)
        doc['Notes'] += _standard_note_addition
        doc = str(doc).split("\n", 1)[1].lstrip(" \n")  # remove signature
        axis_nan_policy_wrapper.__doc__ = str(doc)

        # Add the keyword-only parameters to the reported signature if the
        # wrapped function does not already declare them
        sig = inspect.signature(axis_nan_policy_wrapper)
        parameters = sig.parameters
        parameter_list = list(parameters.values())
        if 'axis' not in parameters:
            parameter_list.append(_axis_parameter)
        if 'nan_policy' not in parameters:
            parameter_list.append(_nan_policy_parameter)
        if 'keepdims' not in parameters:
            parameter_list.append(_keepdims_parameter)
        sig = sig.replace(parameters=parameter_list)
        axis_nan_policy_wrapper.__signature__ = sig

        return axis_nan_policy_wrapper
    return axis_nan_policy_decorator