| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036 |
- """Utility functions to use Python Array API compatible libraries.
- For the context about the Array API see:
- https://data-apis.org/array-api/latest/purpose_and_scope.html
- The SciPy use case of the Array API is described on the following page:
- https://data-apis.org/array-api/latest/use_cases.html#use-case-scipy
- """
- import operator
- import dataclasses
- import functools
- import textwrap
- from collections.abc import Generator
- from contextlib import contextmanager
- from contextvars import ContextVar
- from types import ModuleType
- from typing import Any, Literal, TypeAlias
- from collections.abc import Iterable
- import numpy as np
- import numpy.typing as npt
- from scipy._lib.array_api_compat import (
- is_array_api_obj,
- is_lazy_array,
- is_numpy_array,
- is_cupy_array,
- is_torch_array,
- is_jax_array,
- is_dask_array,
- size as xp_size,
- numpy as np_compat,
- device as xp_device,
- is_numpy_namespace as is_numpy,
- is_cupy_namespace as is_cupy,
- is_torch_namespace as is_torch,
- is_jax_namespace as is_jax,
- is_dask_namespace as is_dask,
- is_array_api_strict_namespace as is_array_api_strict,
- )
- from scipy._lib.array_api_compat.common._helpers import _compat_module_name
- from scipy._lib.array_api_extra.testing import lazy_xp_function
- from scipy._lib._array_api_override import (
- array_namespace, SCIPY_ARRAY_API, SCIPY_DEVICE
- )
- from scipy._lib._docscrape import FunctionDoc
- from scipy._lib import array_api_extra as xpx
# Names exported by a star-import of this module.
__all__ = [
    '_asarray', 'array_namespace', 'assert_almost_equal', 'assert_array_almost_equal',
    'default_xp', 'eager_warns', 'is_lazy_array', 'is_marray',
    'is_array_api_strict', 'is_complex', 'is_cupy', 'is_jax', 'is_numpy', 'is_torch',
    'np_compat', 'get_native_namespace_name',
    'SCIPY_ARRAY_API', 'SCIPY_DEVICE', 'scipy_namespace_for',
    'xp_assert_close', 'xp_assert_equal', 'xp_assert_less',
    'xp_copy', 'xp_device', 'xp_ravel', 'xp_size',
    'xp_unsupported_param_msg', 'xp_vector_norm', 'xp_capabilities',
    'xp_result_type', 'xp_promote',
    'make_xp_test_case', 'make_xp_pytest_marks', 'make_xp_pytest_param',
]

# Annotation-only aliases. `Array` is any array object of any supported
# backend; `ArrayLike` additionally admits anything NumPy can coerce.
Array: TypeAlias = Any  # To be changed to a Protocol later (see array-api#589)
ArrayLike: TypeAlias = Array | npt.ArrayLike
def _check_finite(array: Array, xp: ModuleType) -> None:
    """Raise ``ValueError`` unless every element of `array` is finite.

    Parameters
    ----------
    array : array
        Array to inspect.
    xp : module
        Namespace providing ``isfinite`` and ``all`` for `array`.
    """
    everything_finite = xp.all(xp.isfinite(array))
    if everything_finite:
        return
    raise ValueError("array must not contain infs or NaNs")
- def _asarray(
- array: ArrayLike,
- dtype: Any = None,
- order: Literal['K', 'A', 'C', 'F'] | None = None,
- copy: bool | None = None,
- *,
- xp: ModuleType | None = None,
- check_finite: bool = False,
- subok: bool = False,
- ) -> Array:
- """SciPy-specific replacement for `np.asarray` with `order`, `check_finite`, and
- `subok`.
- Memory layout parameter `order` is not exposed in the Array API standard.
- `order` is only enforced if the input array implementation
- is NumPy based, otherwise `order` is just silently ignored.
- `check_finite` is also not a keyword in the array API standard; included
- here for convenience rather than that having to be a separate function
- call inside SciPy functions.
- `subok` is included to allow this function to preserve the behaviour of
- `np.asanyarray` for NumPy based inputs.
- """
- if xp is None:
- xp = array_namespace(array)
- if is_numpy(xp):
- # Use NumPy API to support order
- if copy is True:
- array = np.array(array, order=order, dtype=dtype, subok=subok)
- elif subok:
- array = np.asanyarray(array, order=order, dtype=dtype)
- else:
- array = np.asarray(array, order=order, dtype=dtype)
- else:
- try:
- array = xp.asarray(array, dtype=dtype, copy=copy)
- except TypeError:
- coerced_xp = array_namespace(xp.asarray(3))
- array = coerced_xp.asarray(array, dtype=dtype, copy=copy)
- if check_finite:
- _check_finite(array, xp)
- return array
def xp_copy(x: Array, *, xp: ModuleType | None = None) -> Array:
    """
    Return a copy of an array.

    Parameters
    ----------
    x : array
        Array to copy.
    xp : array_namespace, optional
        Namespace of `x`. It is inferred with `array_namespace` when omitted.

    Returns
    -------
    copy : array
        Copied array

    Notes
    -----
    Unlike `np.copy`, the `subok` and `order` keywords are not offered.
    """
    # Older NumPy versions' `np.asarray` lacked a `copy` keyword, so the copy
    # goes through `_asarray`, which calls `np.array` for NumPy.
    namespace = array_namespace(x) if xp is None else xp
    return _asarray(x, copy=True, xp=namespace)
- def _xp_copy_to_numpy(x: Array) -> np.ndarray:
- """Copies a possibly on device array to a NumPy array.
- This function is intended only for converting alternative backend
- arrays to numpy arrays within test code, to make it easier for use
- of the alternative backend to be isolated only to the function being
- tested. `_xp_copy_to_numpy` should NEVER be used except in test code
- for the specific purpose mentioned above. In production code, attempts
- to copy device arrays to NumPy arrays should fail, or else functions
- may appear to be working on the GPU when they actually aren't.
- Parameters
- ----------
- x : array
- Returns
- -------
- ndarray
- """
- xp = array_namespace(x)
- if is_numpy(xp):
- return x.copy()
- if is_cupy(xp):
- return x.get()
- if is_torch(xp):
- return x.cpu().numpy()
- if is_array_api_strict(xp):
- # array api strict supports multiple devices, so need to
- # ensure x is on the cpu before copying to NumPy.
- return np.asarray(
- xp.asarray(x, device=xp.Device("CPU_DEVICE")), copy=True
- )
- # Fall back to np.asarray. This works for dask.array. It
- # currently works for jax.numpy, but hopefully JAX will make
- # the transfer guard workable enough for use in scipy tests, in
- # which case, JAX will have to be handled explicitly.
- # If new backends are added, they may require explicit handling as
- # well.
- return np.asarray(x, copy=True)
# Namespace that the assertion helpers use when no explicit `xp` is passed.
# Unset outside of a `default_xp` block.
_default_xp_ctxvar: ContextVar[ModuleType] = ContextVar("_default_xp")


@contextmanager
def default_xp(xp: ModuleType) -> Generator[None, None, None]:
    """Make `xp` the expected array namespace inside the ``with`` block.

    Every ``xp_assert_*`` and ``assert_*`` call executed inside the block
    checks by default that all arrays belong to `xp`. Passing ``xp=``
    explicitly or ``check_namespace=False`` overrides this.

    Outside this context manager, the default for `xp` is the namespace of
    the desired array (the second argument of the assertion).

    The previous value is restored on exit, including when an exception is
    raised inside the block.
    """
    previous_state = _default_xp_ctxvar.set(xp)
    try:
        yield
    finally:
        _default_xp_ctxvar.reset(previous_state)
def eager_warns(warning_type, *, match=None, xp):
    """Return a warnings context manager appropriate for the namespace `xp`.

    For namespaces whose arrays are always eager (NumPy, array-api-strict,
    CuPy), return ``pytest.warns(warning_type, match=match)``, so the warning
    is required. For any other namespace, return a context manager that
    *ignores* `warning_type` instead.
    """
    import pytest
    from scipy._lib._util import ignore_warns

    always_eager = is_numpy(xp) or is_array_api_strict(xp) or is_cupy(xp)
    if not always_eager:
        pattern = '' if match is None else match
        return ignore_warns(warning_type, match=pattern)
    return pytest.warns(warning_type, match=match)
def _strict_check(actual, desired, xp, *,
                  check_namespace=True, check_dtype=True, check_shape=True,
                  check_0d=True):
    """Validate and normalize the inputs of the ``xp_assert_*`` helpers.

    The namespace is taken from the explicit `xp`, else the ``default_xp``
    context value, else the namespace of `desired`. The requested namespace,
    0-d, dtype and shape checks are then run as assertions.

    Returns ``(actual, desired, xp)``. Both arrays have gone through
    ``xp.asarray``, and `desired` has been broadcast to ``actual.shape``.
    """
    __tracebackhide__ = True  # Hide traceback for py.test

    if xp is None:
        try:
            xp = _default_xp_ctxvar.get()
        except LookupError:
            xp = array_namespace(desired)

    if check_namespace:
        _assert_matching_namespace(actual, desired, xp)

    # Only NumPy distinguishes scalars from 0-d arrays. This comparison must
    # happen before both inputs are turned into arrays below.
    if is_numpy(xp) and check_0d:
        actual_is_scalar = xp.isscalar(actual)
        desired_is_scalar = xp.isscalar(desired)
        assert actual_is_scalar == desired_is_scalar, (
            "Array-ness does not match:\n Actual: "
            f"{type(actual)}\n Desired: {type(desired)}")

    actual, desired = xp.asarray(actual), xp.asarray(desired)

    if check_dtype:
        assert actual.dtype == desired.dtype, (
            f"dtypes do not match.\nActual: {actual.dtype}\nDesired: {desired.dtype}")

    if check_shape:
        if is_dask(xp):
            # Dask arrays may carry unknown chunk sizes; resolve them so that
            # `.shape` is concrete before comparing.
            actual.compute_chunk_sizes()
            desired.compute_chunk_sizes()
        assert actual.shape == desired.shape, (
            f"Shapes do not match.\nActual: {actual.shape}\nDesired: {desired.shape}")

    # Broadcasting happens even when `check_shape` is False.
    desired = xp.broadcast_to(desired, actual.shape)
    return actual, desired, xp
- def _assert_matching_namespace(actual, desired, xp):
- __tracebackhide__ = True # Hide traceback for py.test
- desired_arr_space = array_namespace(desired)
- _msg = ("Namespace of desired array does not match expectations "
- "set by the `default_xp` context manager or by the `xp`"
- "pytest fixture.\n"
- f"Desired array's space: {desired_arr_space.__name__}\n"
- f"Expected namespace: {xp.__name__}")
- assert desired_arr_space == xp, _msg
- actual_arr_space = array_namespace(actual)
- _msg = ("Namespace of actual and desired arrays do not match.\n"
- f"Actual: {actual_arr_space.__name__}\n"
- f"Desired: {xp.__name__}")
- assert actual_arr_space == xp, _msg
def xp_assert_equal(actual, desired, *, check_namespace=True, check_dtype=True,
                    check_shape=True, check_0d=True, err_msg='', xp=None):
    """Assert exact element-wise equality of `actual` and `desired`.

    Inputs are first validated by ``_strict_check``. The comparison then uses
    the backend's own testing helper where one is used here (CuPy, PyTorch),
    and `np.testing.assert_array_equal` otherwise. NaNs in matching positions
    compare equal.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    actual, desired, xp = _strict_check(
        actual, desired, xp,
        check_namespace=check_namespace, check_dtype=check_dtype,
        check_shape=check_shape, check_0d=check_0d,
    )

    if is_torch(xp):
        # PyTorch recommends `rtol=0, atol=0` for an exact-equality test.
        # `assert_close` takes `msg=None` rather than an empty string.
        torch_msg = err_msg if err_msg != '' else None
        return xp.testing.assert_close(actual, desired, rtol=0, atol=0,
                                       equal_nan=True, check_dtype=False,
                                       msg=torch_msg)
    if is_cupy(xp):
        return xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)
    # Every other backend (e.g. JAX) goes through `np.testing`.
    return np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
def xp_assert_close(actual, desired, *, rtol=None, atol=0, check_namespace=True,
                    check_dtype=True, check_shape=True, check_0d=True,
                    err_msg='', xp=None):
    """Assert that `actual` and `desired` agree within `rtol`/`atol`.

    When `rtol` is None it defaults to ``4 * sqrt(eps)`` of ``actual.dtype``
    for real/complex floating dtypes, and to ``1e-7`` otherwise. NaNs in
    matching positions compare equal.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    actual, desired, xp = _strict_check(
        actual, desired, xp,
        check_namespace=check_namespace, check_dtype=check_dtype,
        check_shape=check_shape, check_0d=check_0d,
    )

    is_floating = xp.isdtype(actual.dtype, ('real floating', 'complex floating'))
    if rtol is None:
        # The factor of 4 puts the float64 default roughly half way between
        # sqrt(eps) and the `numpy.testing.assert_allclose` default of 1e-7.
        rtol = xp.finfo(actual.dtype).eps**0.5 * 4 if is_floating else 1e-7

    if is_torch(xp):
        torch_msg = err_msg if err_msg != '' else None
        return xp.testing.assert_close(actual, desired, rtol=rtol, atol=atol,
                                       equal_nan=True, check_dtype=False,
                                       msg=torch_msg)
    if is_cupy(xp):
        return xp.testing.assert_allclose(actual, desired, rtol=rtol,
                                          atol=atol, err_msg=err_msg)
    # Every other backend (e.g. JAX) goes through `np.testing`.
    return np.testing.assert_allclose(actual, desired, rtol=rtol,
                                      atol=atol, err_msg=err_msg)
def xp_assert_close_nulp(actual, desired, *, nulp=1, check_namespace=True,
                         check_dtype=True, check_shape=True, check_0d=True,
                         err_msg='', xp=None):
    """Assert that `actual` and `desired` agree to within `nulp` ULPs.

    The inputs are validated like in the other ``xp_assert_*`` helpers, copied
    to NumPy (this is test-only code), and compared with
    `numpy.testing.assert_array_almost_equal_nulp`.

    That NumPy function has no ``err_msg`` parameter. A non-empty `err_msg` is
    therefore appended to the ``AssertionError`` message here, instead of
    being silently dropped as it was before.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    actual, desired, xp = _strict_check(
        actual, desired, xp,
        check_namespace=check_namespace, check_dtype=check_dtype,
        check_shape=check_shape, check_0d=check_0d
    )
    actual, desired = map(_xp_copy_to_numpy, (actual, desired))
    try:
        return np.testing.assert_array_almost_equal_nulp(actual, desired, nulp=nulp)
    except AssertionError as exc:
        if not err_msg:
            raise
        raise AssertionError(f"{exc}\n{err_msg}") from exc
def xp_assert_less(actual, desired, *, check_namespace=True, check_dtype=True,
                   check_shape=True, check_0d=True, err_msg='', verbose=True, xp=None):
    """Assert that `actual` is element-wise strictly less than `desired`."""
    __tracebackhide__ = True  # Hide traceback for py.test
    actual, desired, xp = _strict_check(
        actual, desired, xp,
        check_namespace=check_namespace, check_dtype=check_dtype,
        check_shape=check_shape, check_0d=check_0d,
    )

    if is_cupy(xp):
        return xp.testing.assert_array_less(actual, desired,
                                            err_msg=err_msg, verbose=verbose)
    if is_torch(xp):
        # `np.testing` is used below, so bring non-CPU tensors to the CPU first.
        actual, desired = (
            tensor if tensor.device.type == 'cpu' else tensor.cpu()
            for tensor in (actual, desired)
        )
    # Torch (now on the CPU) and every other backend (e.g. JAX) use `np.testing`.
    return np.testing.assert_array_less(actual, desired,
                                        err_msg=err_msg, verbose=verbose)
def assert_array_almost_equal(actual, desired, decimal=6, *args, **kwds):
    """Backwards compatible replacement. In new code, use xp_assert_close instead.

    Checks ``|actual - desired| <= 1.5 * 10**(-decimal)`` (absolute tolerance
    only), ignoring dtype and shape differences.
    """
    tolerance = 1.5*10**(-decimal)
    return xp_assert_close(actual, desired, *args,
                           atol=tolerance, rtol=0,
                           check_dtype=False, check_shape=False, **kwds)
def assert_almost_equal(actual, desired, decimal=7, *args, **kwds):
    """Backwards compatible replacement. In new code, use xp_assert_close instead.

    Checks ``|actual - desired| <= 1.5 * 10**(-decimal)`` (absolute tolerance
    only), ignoring dtype and shape differences.
    """
    tolerance = 1.5*10**(-decimal)
    return xp_assert_close(actual, desired, *args,
                           atol=tolerance, rtol=0,
                           check_dtype=False, check_shape=False, **kwds)
def xp_unsupported_param_msg(param: Any) -> str:
    """Return the message used when `param` is given with a non-NumPy array."""
    return 'Providing {!r} is only supported for numpy arrays.'.format(param)
- def is_complex(x: Array, xp: ModuleType) -> bool:
- return xp.isdtype(x.dtype, 'complex floating')
- def get_native_namespace_name(xp: ModuleType) -> str:
- """Return name for native namespace (without array_api_compat prefix)."""
- name = xp.__name__
- return name.removeprefix(f"{_compat_module_name()}.")
- def scipy_namespace_for(xp: ModuleType) -> ModuleType | None:
- """Return the `scipy`-like namespace of a non-NumPy backend
- That is, return the namespace corresponding with backend `xp` that contains
- `scipy` sub-namespaces like `linalg` and `special`. If no such namespace
- exists, return ``None``. Useful for dispatching.
- """
- if is_cupy(xp):
- import cupyx # type: ignore[import-not-found,import-untyped]
- return cupyx.scipy
- if is_jax(xp):
- import jax # type: ignore[import-not-found]
- return jax.scipy
- if is_torch(xp):
- return xp
- return None
- # maybe use `scipy.linalg` if/when array API support is added
- def xp_vector_norm(x: Array, /, *,
- axis: int | tuple[int] | None = None,
- keepdims: bool = False,
- ord: int | float = 2,
- xp: ModuleType | None = None) -> Array:
- xp = array_namespace(x) if xp is None else xp
- if SCIPY_ARRAY_API:
- # check for optional `linalg` extension
- if hasattr(xp, 'linalg'):
- return xp.linalg.vector_norm(x, axis=axis, keepdims=keepdims, ord=ord)
- else:
- if ord != 2:
- raise ValueError(
- "only the Euclidean norm (`ord=2`) is currently supported in "
- "`xp_vector_norm` for backends not implementing the `linalg` "
- "extension."
- )
- # return (x @ x)**0.5
- # or to get the right behavior with nd, complex arrays
- return xp.sum(xp.conj(x) * x, axis=axis, keepdims=keepdims)**0.5
- else:
- # to maintain backwards compatibility
- return np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)
- def xp_ravel(x: Array, /, *, xp: ModuleType | None = None) -> Array:
- # Equivalent of np.ravel written in terms of array API
- # Even though it's one line, it comes up so often that it's worth having
- # this function for readability
- xp = array_namespace(x) if xp is None else xp
- return xp.reshape(x, (-1,))
- def xp_swapaxes(a, axis1, axis2, xp=None):
- # Equivalent of np.swapaxes written in terms of array API
- xp = array_namespace(a) if xp is None else xp
- axes = list(range(a.ndim))
- axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
- a = xp.permute_dims(a, axes)
- return a
- # utility to find common dtype with option to force floating
- def xp_result_type(*args, force_floating=False, xp):
- """
- Returns the dtype that results from applying type promotion rules
- (see Array API Standard Type Promotion Rules) to the arguments. Augments
- standard `result_type` in a few ways:
- - There is a `force_floating` argument that ensures that the result type
- is floating point, even when all args are integer.
- - When a TypeError is raised (e.g. due to an unsupported promotion)
- and `force_floating=True`, we define a custom rule: use the result type
- of the default float and any other floats passed. See
- https://github.com/scipy/scipy/pull/22695/files#r1997905891
- for rationale.
- - This function accepts array-like iterables, which are immediately converted
- to the namespace's arrays before result type calculation. Consequently, the
- result dtype may be different when an argument is `1.` vs `[1.]`.
- Typically, this function will be called shortly after `array_namespace`
- on a subset of the arguments passed to `array_namespace`.
- """
- # prevent double conversion of iterable to array
- # avoid `np.iterable` for torch arrays due to pytorch/pytorch#143334
- # don't use `array_api_compat.is_array_api_obj` as it returns True for NumPy scalars
- args = [(_asarray(arg, subok=True, xp=xp) if is_torch_array(arg) or np.iterable(arg)
- else arg) for arg in args]
- args_not_none = [arg for arg in args if arg is not None]
- if force_floating:
- args_not_none.append(1.0)
- if is_numpy(xp) and xp.__version__ < '2.0':
- # Follow NEP 50 promotion rules anyway
- args_not_none = [arg.dtype if getattr(arg, 'size', 0) == 1 else arg
- for arg in args_not_none]
- return xp.result_type(*args_not_none)
- try: # follow library's preferred promotion rules
- return xp.result_type(*args_not_none)
- except TypeError: # mixed type promotion isn't defined
- if not force_floating:
- raise
- # use `result_type` of default floating point type and any floats present
- # This can be revisited, but right now, the only backends that get here
- # are array-api-strict (which is not for production use) and PyTorch
- # (due to data-apis/array-api-compat#279).
- float_args = []
- for arg in args_not_none:
- arg_array = xp.asarray(arg) if np.isscalar(arg) else arg
- dtype = getattr(arg_array, 'dtype', arg)
- if xp.isdtype(dtype, ('real floating', 'complex floating')):
- float_args.append(arg)
- return xp.result_type(*float_args, xp_default_dtype(xp))
def xp_promote(*args, broadcast=False, force_floating=False, xp):
    """
    Promote the non-``None`` elements of `args` to a common dtype.

    The result dtype follows `xp_result_type` rather than `xp.result_type`,
    so ``force_floating=True`` guarantees a floating-point result. With
    ``broadcast=True`` the promoted arrays are also broadcast to a common
    shape. ``None`` entries pass through untouched in both steps.

    Typically, this function will be called shortly after `array_namespace`
    on a subset of the arguments passed to `array_namespace`.

    This function accepts array-like iterables, which are immediately converted
    to the namespace's arrays before result type calculation. Consequently, the
    result dtype may be different when an argument is `1.` vs `[1.]`.

    Returns the single promoted argument when exactly one is given, otherwise
    a tuple (an empty tuple when no arguments are given).

    See Also
    --------
    xp_result_type
    """
    if not args:
        return args

    def _coerce_iterable(arg):
        # Convert iterables once here so that `xp_result_type` does not convert
        # them again. `np.iterable` is avoided for torch tensors
        # (pytorch/pytorch#143334). `is_array_api_obj` is not used because it
        # returns True for NumPy scalars.
        if is_torch_array(arg) or np.iterable(arg):
            return _asarray(arg, subok=True, xp=xp)
        return arg

    coerced = [_coerce_iterable(arg) for arg in args]
    dtype = xp_result_type(*coerced, force_floating=force_floating, xp=xp)
    promoted = [None if arg is None
                else _asarray(arg, dtype=dtype, subok=True, xp=xp)
                for arg in coerced]

    if not broadcast:
        return promoted[0] if len(promoted) == 1 else tuple(promoted)

    present = [arg for arg in promoted if arg is not None]
    shapes = {arg.shape for arg in present}
    try:
        # A single distinct shape needs no broadcasting computation.
        target_shape = (present[0].shape if len(shapes) == 1
                        else np.broadcast_shapes(*shapes))
    except ValueError as e:
        message = "Array shapes are incompatible for broadcasting."
        raise ValueError(message) from e

    broadcast_kwargs = {'subok': True} if is_numpy(xp) else {}
    result = []
    for arg in promoted:
        if arg is not None:
            # Broadcasting each argument only when needed, against the shape
            # computed above, is faster than `broadcast_arrays`.
            if arg.shape != target_shape:
                arg = xp.broadcast_to(arg, target_shape, **broadcast_kwargs)
            # Conditional cast: much faster than xp.astype(arg, dtype, copy=False)
            if arg.dtype != dtype:
                arg = xp.astype(arg, dtype)
        result.append(arg)
    return result[0] if len(result) == 1 else tuple(result)
- def xp_float_to_complex(arr: Array, xp: ModuleType | None = None) -> Array:
- xp = array_namespace(arr) if xp is None else xp
- arr_dtype = arr.dtype
- # The standard float dtypes are float32 and float64.
- # Convert float32 to complex64,
- # and float64 (and non-standard real dtypes) to complex128
- if xp.isdtype(arr_dtype, xp.float32):
- arr = xp.astype(arr, xp.complex64)
- elif xp.isdtype(arr_dtype, 'real floating'):
- arr = xp.astype(arr, xp.complex128)
- return arr
- def xp_default_dtype(xp):
- """Query the namespace-dependent default floating-point dtype.
- """
- if is_torch(xp):
- # historically, we allow pytorch to keep its default of float32
- return xp.get_default_dtype()
- else:
- # we default to float64
- return xp.float64
- ### MArray Helpers ###
- def xp_result_device(*args):
- """Return the device of an array in `args`, for the purpose of
- input-output device propagation.
- If there are multiple devices, return an arbitrary one.
- If there are no arrays, return None (this typically happens only on NumPy).
- """
- for arg in args:
- # Do not do a duck-type test for the .device attribute, as many backends today
- # don't have it yet. See workarouunds in array_api_compat.device().
- if is_array_api_obj(arg):
- return xp_device(arg)
- return None
- # np.r_ replacement
- def concat_1d(xp: ModuleType | None, *arrays: Iterable[ArrayLike]) -> Array:
- """A replacement for `np.r_` as `xp.concat` does not accept python scalars
- or 0-D arrays.
- """
- arys = [xpx.atleast_nd(xp.asarray(a), ndim=1, xp=xp) for a in arrays]
- return xp.concat(arys)
def is_marray(xp):
    """Returns True if `xp` is an MArray namespace; False otherwise.

    The test is a substring match: any namespace whose ``__name__`` contains
    "marray" qualifies.
    """
    return xp.__name__.find("marray") >= 0
def _length_nonmasked(x, axis, keepdims=False, xp=None):
    """Number of elements of `x` along `axis`, excluding masked ones for MArray.

    For MArray namespaces this returns ``xp.count`` cast to ``x.dtype``; a
    tuple of axes raises ``NotImplementedError`` there. For other namespaces
    it returns a Python int: the total size when `axis` is None, otherwise the
    product of the selected axis lengths. In that case `keepdims` is unused.
    """
    xp = array_namespace(x) if xp is None else xp

    if not is_marray(xp):
        if axis is None:
            return xp_size(x)
        # Fancy-index the shape so that a single int and a tuple of axes are
        # handled the same way.
        selected_lengths = np.asarray(x.shape)[np.asarray(axis)]
        return int(np.prod(selected_lengths))

    if np.iterable(axis):
        message = '`axis` must be an integer or None for use with `MArray`.'
        raise NotImplementedError(message)
    counts = xp.count(x, axis=axis, keepdims=keepdims)
    return xp.astype(counts, x.dtype)
def _share_masks(*args, xp):
    """Give every MArray argument the union (logical OR) of all their masks.

    For non-MArray namespaces the arguments are returned unchanged. A single
    argument is returned on its own; several arguments are returned as a list
    (MArray) or as the original tuple (other namespaces).
    """
    if not is_marray(xp):
        return args[0] if len(args) == 1 else args

    union_mask = functools.reduce(operator.or_, (item.mask for item in args))
    remasked = [xp.asarray(item.data, mask=union_mask) for item in args]
    return remasked[0] if len(remasked) == 1 else remasked
- ### End MArray Helpers ###
- @dataclasses.dataclass(repr=False)
- class _XPSphinxCapability:
- cpu: bool | None # None if not applicable
- gpu: bool | None
- warnings: list[str] = dataclasses.field(default_factory=list)
- def _render(self, value):
- if value is None:
- return "n/a"
- if not value:
- return "⛔"
- if self.warnings:
- res = "⚠️ " + '; '.join(self.warnings)
- assert len(res) <= 20, "Warnings too long"
- return res
- return "✅"
- def __str__(self):
- cpu = self._render(self.cpu)
- gpu = self._render(self.gpu)
- return f"{cpu:20} {gpu:20}"
def _make_sphinx_capabilities(
    # lists of tuples [(module name, reason), ...]
    skip_backends=(), xfail_backends=(),
    # @pytest.mark.skip/xfail_xp_backends kwargs
    cpu_only=False, np_only=False, out_of_scope=False, exceptions=(),
    # xpx.lazy_xp_backends kwargs
    allow_dask_compute=False, jax_jit=True,
    # list of tuples [(module name, reason), ...]
    warnings = (),
    # unused in documentation
    reason=None,
):
    """Translate ``xp_capabilities`` arguments into per-backend table rows.

    Returns ``{"out_of_scope": True}`` for out-of-scope functions. Otherwise
    it returns a dict mapping backend module name to an `_XPSphinxCapability`.
    Reasons attached to skips and xfails are not displayed, so they are
    discarded.
    """
    if out_of_scope:
        return {"out_of_scope": True}

    exceptions = set(exceptions)

    # Default capabilities
    capabilities = {
        "numpy": _XPSphinxCapability(cpu=True, gpu=None),
        "array_api_strict": _XPSphinxCapability(cpu=True, gpu=None),
        "cupy": _XPSphinxCapability(cpu=None, gpu=True),
        "torch": _XPSphinxCapability(cpu=True, gpu=True),
        "jax.numpy": _XPSphinxCapability(cpu=True, gpu=True,
            warnings=[] if jax_jit else ["no JIT"]),
        # Note: Dask+CuPy is currently untested and unsupported
        "dask.array": _XPSphinxCapability(cpu=True, gpu=None,
            warnings=["computes graph"] if allow_dask_compute else []),
    }

    def _mark_unsupported(row):
        # Flip every applicable device to False; "n/a" (None) cells stay as-is.
        if row.cpu is not None:
            row.cpu = False
        if row.gpu is not None:
            row.gpu = False

    for module, _reason in [*skip_backends, *xfail_backends]:
        _mark_unsupported(capabilities[module])

    for module, row in capabilities.items():
        if np_only and module not in exceptions | {"numpy"}:
            _mark_unsupported(row)
        elif cpu_only and module not in exceptions and row.gpu is not None:
            row.gpu = False

    for module, warning in warnings:
        capabilities[module].warnings.append(warning)

    return capabilities
def _make_capabilities_note(fun_name, capabilities, extra_note=None):
    """Build the "Array API Standard Support" note for a function's docstring.

    Parameters
    ----------
    fun_name : str
        Name of the documented function, interpolated into the note.
    capabilities : dict
        Either ``{"out_of_scope": True}`` or a mapping from backend module
        name ("numpy", "cupy", "torch", "jax.numpy", "dask.array", ...) to an
        object whose ``str()`` renders that backend's CPU/GPU table cells.
    extra_note : str, optional
        Text inserted just before the closing "See :ref:`dev-arrayapi`"
        sentence of the supported-backends note. It is not used for
        out-of-scope functions.

    Returns
    -------
    str
        The note, passed through `textwrap.dedent`.
    """
    if "out_of_scope" in capabilities:
        # It will be better to link to a section of the dev-arrayapi docs
        # that explains what is and isn't in-scope, but such a section
        # doesn't exist yet. Using :ref:`dev-arrayapi` as a placeholder.
        note = f"""
**Array API Standard Support**
`{fun_name}` is not in-scope for support of Python Array API Standard compatible
backends other than NumPy.
See :ref:`dev-arrayapi` for more information.
"""
        return textwrap.dedent(note)
    # Note: deliberately not documenting array-api-strict
    # (only the five rows below are rendered, even though `capabilities` may
    # also hold an "array_api_strict" entry).
    note = f"""
**Array API Standard Support**
`{fun_name}` has experimental support for Python Array API Standard compatible
backends in addition to NumPy. Please consider testing these features
by setting an environment variable ``SCIPY_ARRAY_API=1`` and providing
CuPy, PyTorch, JAX, or Dask arrays as array arguments. The following
combinations of backend and device (or other capability) are supported.
==================== ==================== ====================
Library CPU GPU
==================== ==================== ====================
NumPy {capabilities['numpy'] }
CuPy {capabilities['cupy'] }
PyTorch {capabilities['torch'] }
JAX {capabilities['jax.numpy'] }
Dask {capabilities['dask.array'] }
==================== ==================== ====================
""" + (extra_note or "") + " See :ref:`dev-arrayapi` for more information."
    return textwrap.dedent(note)
def xp_capabilities(
    *,
    # Alternative capabilities table.
    # Used only for testing this decorator.
    capabilities_table=None,
    # Generate pytest.mark.skip/xfail_xp_backends.
    # See documentation in conftest.py.
    # lists of tuples [(module name, reason), ...]
    skip_backends=(), xfail_backends=(),
    cpu_only=False, np_only=False, reason=None,
    out_of_scope=False, exceptions=(),
    # lists of tuples [(module name, reason), ...]
    warnings=(),
    # xpx.testing.lazy_xp_function kwargs.
    # Refer to array-api-extra documentation.
    allow_dask_compute=False, jax_jit=True,
    # Extra note to inject into the docstring
    extra_note=None,
):
    """Decorator for a function that states its support among various
    Array API compatible backends.

    This decorator has two effects:

    1. It records the capabilities in a table keyed by the decorated function.
       This allows tests tagged with ``@make_xp_test_case`` or
       ``make_xp_pytest_param`` (see below) to automatically get
       SKIP/XFAIL markers and additional backend-specific testing, such as
       extra validation for Dask and JAX.
    2. It appends a note to the function's docstring with a table matching
       what has been tested.

    The function itself is returned unwrapped, because the decorator is
    sometimes applied to ufuncs.

    See Also
    --------
    make_xp_test_case
    make_xp_pytest_param
    array_api_extra.testing.lazy_xp_function
    """
    if capabilities_table is None:
        capabilities_table = xp_capabilities_table

    # Out-of-scope functions are tested as NumPy-only.
    if out_of_scope:
        np_only = True

    capabilities = {
        "skip_backends": skip_backends,
        "xfail_backends": xfail_backends,
        "cpu_only": cpu_only,
        "np_only": np_only,
        "out_of_scope": out_of_scope,
        "reason": reason,
        "exceptions": exceptions,
        "allow_dask_compute": allow_dask_compute,
        "jax_jit": jax_jit,
        "warnings": warnings,
    }
    sphinx_capabilities = _make_sphinx_capabilities(**capabilities)

    def decorator(f):
        # Don't use a wrapper, as in some cases @xp_capabilities is
        # applied to a ufunc
        capabilities_table[f] = capabilities

        parsed_doc = FunctionDoc(f)
        parsed_doc['Notes'].append(
            _make_capabilities_note(f.__name__, sphinx_capabilities, extra_note)
        )
        # The rendered FunctionDoc starts with the signature; drop that line.
        new_doc = str(parsed_doc).split("\n", 1)[1].lstrip(" \n")
        try:
            f.__doc__ = new_doc
        except AttributeError:
            # Can't update __doc__ on ufuncs if SciPy
            # was compiled against NumPy < 2.2.
            pass
        return f

    return decorator
def make_xp_test_case(*funcs, capabilities_table=None):
    """Generate pytest decorator for a test function that tests functionality
    of one or more Array API compatible functions.

    Read the parameters of the ``@xp_capabilities`` decorator applied to the
    listed functions and:

    - Generate the ``@pytest.mark.skip_xp_backends`` and
      ``@pytest.mark.xfail_xp_backends`` decorators
      for the decorated test function
    - Tag the function with `xpx.testing.lazy_xp_function`

    Example::

        @make_xp_test_case(f1)
        def test_f1(xp):
            ...

        @make_xp_test_case(f2)
        def test_f2(xp):
            ...

        @make_xp_test_case(f1, f2)
        def test_f1_and_f2(xp):
            ...

    The above is equivalent to::

        @pytest.mark.skip_xp_backends(...)
        @pytest.mark.skip_xp_backends(...)
        @pytest.mark.xfail_xp_backends(...)
        @pytest.mark.xfail_xp_backends(...)
        def test_f1(xp):
            ...

    etc., where the arguments of ``skip_xp_backends`` and ``xfail_xp_backends`` are
    determined by the ``@xp_capabilities`` decorator applied to the functions.

    See Also
    --------
    xp_capabilities
    make_xp_pytest_marks
    make_xp_pytest_param
    array_api_extra.testing.lazy_xp_function
    """
    # The docstring must be the first statement of the body. Previously this
    # assignment came first, which left the text above as a dead string
    # expression and `make_xp_test_case.__doc__` as None.
    capabilities_table = (xp_capabilities_table if capabilities_table is None
                          else capabilities_table)
    marks = make_xp_pytest_marks(*funcs, capabilities_table=capabilities_table)
    # Apply each mark in turn to the test function.
    return lambda func: functools.reduce(lambda f, g: g(f), marks, func)
def make_xp_pytest_param(func, *args, capabilities_table=None):
    """Wrap a function under test in a ``pytest.param`` carrying its marks.

    This is the ``parametrize``-friendly counterpart of
    ``make_xp_test_case``: the returned parameter already has every
    ``skip_xp_backends`` / ``xfail_xp_backends`` mark derived from the
    function's ``@xp_capabilities``::

        @pytest.mark.parametrize(
            "func", [make_xp_pytest_param(f1), make_xp_pytest_param(f2)]
        )
        def test(func, xp):
            ...

    which is the same as writing::

        @pytest.mark.parametrize(
            "func", [
                pytest.param(f1, marks=[
                    pytest.mark.skip_xp_backends(...),
                    pytest.mark.xfail_xp_backends(...), ...]),
                pytest.param(f2, marks=[
                    pytest.mark.skip_xp_backends(...),
                    pytest.mark.xfail_xp_backends(...), ...]),
            ]
        )
        def test(func, xp):
            ...

    Parameters
    ----------
    func : Callable
        Function to be tested. It must be decorated with ``@xp_capabilities``.
    *args : Any, optional
        Additional values for the same parameter set, placed after ``func``::

            @pytest.mark.parametrize("func,verb", [
                make_xp_pytest_param(f1, "hello"),
                make_xp_pytest_param(f2, "world")])
            def test(func, verb, xp):
                # runs with (func=f1, verb="hello")
                # and with (func=f2, verb="world")
    capabilities_table : dict, optional
        Registry to look ``func`` up in; ``None`` selects the module-level
        ``xp_capabilities_table``.

    Returns
    -------
    ParameterSet
        ``pytest.param(func, *args, marks=..., id=func.__name__)``.

    See Also
    --------
    xp_capabilities
    make_xp_test_case
    make_xp_pytest_marks
    array_api_extra.testing.lazy_xp_function
    """
    import pytest

    return pytest.param(
        func,
        *args,
        marks=make_xp_pytest_marks(func, capabilities_table=capabilities_table),
        id=func.__name__,
    )
def make_xp_pytest_marks(*funcs, capabilities_table=None):
    """Variant of ``make_xp_test_case`` that returns a list of pytest marks,
    which can be used with the module-level ``pytestmark = ...`` variable::

        pytestmark = make_xp_pytest_marks(f1, f2)

        def test(xp):
            ...

    In this example, the whole test module is dedicated to testing ``f1`` or
    ``f2``, and the two functions have the same capabilities, so it's
    unnecessary to cherry-pick which test tests which function.
    The above is equivalent to::

        pytestmark = [
            pytest.mark.skip_xp_backends(...),
            pytest.mark.xfail_xp_backends(...), ...
        ]

        def test(xp):
            ...

    Parameters
    ----------
    *funcs : Callable
        Functions to collect marks for. Each one must be registered in
        ``capabilities_table``, i.e. decorated with ``@xp_capabilities``.
    capabilities_table : dict, optional
        Registry to read the capabilities from. Defaults to the module-level
        ``xp_capabilities_table``.

    Returns
    -------
    list
        The ``skip_xp_backends`` / ``xfail_xp_backends`` marks of all
        ``funcs``, in the order of ``funcs``.

    Raises
    ------
    KeyError
        If a function in ``funcs`` is not present in ``capabilities_table``
        (for the default dict registry).

    See Also
    --------
    xp_capabilities
    make_xp_test_case
    make_xp_pytest_param
    array_api_extra.testing.lazy_xp_function

    Notes
    -----
    As a side effect, every function in ``funcs`` is passed to
    `array_api_extra.testing.lazy_xp_function` with its recorded
    ``allow_dask_compute`` and ``jax_jit`` capabilities.
    """
    import pytest

    if capabilities_table is None:
        capabilities_table = xp_capabilities_table

    marks = []
    for func in funcs:
        capabilities = capabilities_table[func]
        exceptions = capabilities['exceptions']
        reason = capabilities['reason']

        # Blanket restrictions share the function-wide reason/exceptions.
        if capabilities['cpu_only']:
            marks.append(pytest.mark.skip_xp_backends(
                cpu_only=True, exceptions=exceptions, reason=reason))
        if capabilities['np_only']:
            marks.append(pytest.mark.skip_xp_backends(
                np_only=True, exceptions=exceptions, reason=reason))

        # Per-backend entries carry their own reason; keep it in a separate
        # name so it can never be confused with the function-wide ``reason``.
        for mod_name, backend_reason in capabilities['skip_backends']:
            marks.append(
                pytest.mark.skip_xp_backends(mod_name, reason=backend_reason))
        for mod_name, backend_reason in capabilities['xfail_backends']:
            marks.append(
                pytest.mark.xfail_xp_backends(mod_name, reason=backend_reason))

        lazy_xp_function(
            func,
            allow_dask_compute=capabilities['allow_dask_compute'],
            jax_jit=capabilities['jax_jit'],
        )
    return marks
# Registry mapping each function decorated with ``@xp_capabilities`` to the
# capabilities dict recorded by the decorator. It is filled in as a side effect
# of decoration (i.e. as the decorated modules are imported) and is the default
# ``capabilities_table`` read by ``make_xp_test_case``,
# ``make_xp_pytest_param`` and ``make_xp_pytest_marks``.
# NOTE(review): this is global mutable state written from many modules; a
# dedicated registry object may be worth considering.
xp_capabilities_table = {}  # type: ignore[var-annotated]
def xp_device_type(a: Array) -> Literal["cpu", "cuda", None]:
    """Report the kind of device that ``a`` lives on.

    Parameters
    ----------
    a : Array
        Array from any supported backend.

    Returns
    -------
    {"cpu", "cuda", None}
        ``"cpu"`` for NumPy arrays and ``"cuda"`` for CuPy arrays. For PyTorch
        arrays this is ``a.device.type``. For JAX arrays it is
        ``a.device.platform``, with ``"gpu"`` translated to ``"cuda"``. Dask
        arrays report the device type of ``a._meta``. ``None`` is returned
        for any other library.
    """
    device_type: Literal["cpu", "cuda", None]
    if is_numpy_array(a):
        device_type = "cpu"
    elif is_cupy_array(a):
        device_type = "cuda"
    elif is_torch_array(a):
        # TODO this can return other backends e.g. tpu but they're unsupported in scipy
        device_type = a.device.type
    elif is_jax_array(a):
        # TODO this can return other backends e.g. tpu but they're unsupported in scipy
        platform = a.device.platform
        device_type = "cuda" if platform == "gpu" else platform
    elif is_dask_array(a):
        # Delegate to the array stored in ``_meta`` (presumably Dask's
        # prototype of its chunk type).
        device_type = xp_device_type(a._meta)
    else:
        # array-api-strict is a stand-in for unknown libraries; don't special-case it
        device_type = None
    return device_type
|