_array_api.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036
  1. """Utility functions to use Python Array API compatible libraries.
  2. For the context about the Array API see:
  3. https://data-apis.org/array-api/latest/purpose_and_scope.html
  4. The SciPy use case of the Array API is described on the following page:
  5. https://data-apis.org/array-api/latest/use_cases.html#use-case-scipy
  6. """
  7. import operator
  8. import dataclasses
  9. import functools
  10. import textwrap
  11. from collections.abc import Generator
  12. from contextlib import contextmanager
  13. from contextvars import ContextVar
  14. from types import ModuleType
  15. from typing import Any, Literal, TypeAlias
  16. from collections.abc import Iterable
  17. import numpy as np
  18. import numpy.typing as npt
  19. from scipy._lib.array_api_compat import (
  20. is_array_api_obj,
  21. is_lazy_array,
  22. is_numpy_array,
  23. is_cupy_array,
  24. is_torch_array,
  25. is_jax_array,
  26. is_dask_array,
  27. size as xp_size,
  28. numpy as np_compat,
  29. device as xp_device,
  30. is_numpy_namespace as is_numpy,
  31. is_cupy_namespace as is_cupy,
  32. is_torch_namespace as is_torch,
  33. is_jax_namespace as is_jax,
  34. is_dask_namespace as is_dask,
  35. is_array_api_strict_namespace as is_array_api_strict,
  36. )
  37. from scipy._lib.array_api_compat.common._helpers import _compat_module_name
  38. from scipy._lib.array_api_extra.testing import lazy_xp_function
  39. from scipy._lib._array_api_override import (
  40. array_namespace, SCIPY_ARRAY_API, SCIPY_DEVICE
  41. )
  42. from scipy._lib._docscrape import FunctionDoc
  43. from scipy._lib import array_api_extra as xpx
# Names exported by ``from scipy._lib._array_api import *``.
__all__ = [
    '_asarray', 'array_namespace', 'assert_almost_equal', 'assert_array_almost_equal',
    'default_xp', 'eager_warns', 'is_lazy_array', 'is_marray',
    'is_array_api_strict', 'is_complex', 'is_cupy', 'is_jax', 'is_numpy', 'is_torch',
    'np_compat', 'get_native_namespace_name',
    'SCIPY_ARRAY_API', 'SCIPY_DEVICE', 'scipy_namespace_for',
    'xp_assert_close', 'xp_assert_equal', 'xp_assert_less',
    'xp_copy', 'xp_device', 'xp_ravel', 'xp_size',
    'xp_unsupported_param_msg', 'xp_vector_norm', 'xp_capabilities',
    'xp_result_type', 'xp_promote',
    'make_xp_test_case', 'make_xp_pytest_marks', 'make_xp_pytest_param',
]

# Placeholder type for "an array of any Array API compatible library".
Array: TypeAlias = Any  # To be changed to a Protocol later (see array-api#589)
# Anything accepted where an array is expected: a backend array or a NumPy array-like.
ArrayLike: TypeAlias = Array | npt.ArrayLike
  58. def _check_finite(array: Array, xp: ModuleType) -> None:
  59. """Check for NaNs or Infs."""
  60. if not xp.all(xp.isfinite(array)):
  61. msg = "array must not contain infs or NaNs"
  62. raise ValueError(msg)
  63. def _asarray(
  64. array: ArrayLike,
  65. dtype: Any = None,
  66. order: Literal['K', 'A', 'C', 'F'] | None = None,
  67. copy: bool | None = None,
  68. *,
  69. xp: ModuleType | None = None,
  70. check_finite: bool = False,
  71. subok: bool = False,
  72. ) -> Array:
  73. """SciPy-specific replacement for `np.asarray` with `order`, `check_finite`, and
  74. `subok`.
  75. Memory layout parameter `order` is not exposed in the Array API standard.
  76. `order` is only enforced if the input array implementation
  77. is NumPy based, otherwise `order` is just silently ignored.
  78. `check_finite` is also not a keyword in the array API standard; included
  79. here for convenience rather than that having to be a separate function
  80. call inside SciPy functions.
  81. `subok` is included to allow this function to preserve the behaviour of
  82. `np.asanyarray` for NumPy based inputs.
  83. """
  84. if xp is None:
  85. xp = array_namespace(array)
  86. if is_numpy(xp):
  87. # Use NumPy API to support order
  88. if copy is True:
  89. array = np.array(array, order=order, dtype=dtype, subok=subok)
  90. elif subok:
  91. array = np.asanyarray(array, order=order, dtype=dtype)
  92. else:
  93. array = np.asarray(array, order=order, dtype=dtype)
  94. else:
  95. try:
  96. array = xp.asarray(array, dtype=dtype, copy=copy)
  97. except TypeError:
  98. coerced_xp = array_namespace(xp.asarray(3))
  99. array = coerced_xp.asarray(array, dtype=dtype, copy=copy)
  100. if check_finite:
  101. _check_finite(array, xp)
  102. return array
  103. def xp_copy(x: Array, *, xp: ModuleType | None = None) -> Array:
  104. """
  105. Copies an array.
  106. Parameters
  107. ----------
  108. x : array
  109. xp : array_namespace
  110. Returns
  111. -------
  112. copy : array
  113. Copied array
  114. Notes
  115. -----
  116. This copy function does not offer all the semantics of `np.copy`, i.e. the
  117. `subok` and `order` keywords are not used.
  118. """
  119. # Note: for older NumPy versions, `np.asarray` did not support the `copy` kwarg,
  120. # so this uses our other helper `_asarray`.
  121. if xp is None:
  122. xp = array_namespace(x)
  123. return _asarray(x, copy=True, xp=xp)
  124. def _xp_copy_to_numpy(x: Array) -> np.ndarray:
  125. """Copies a possibly on device array to a NumPy array.
  126. This function is intended only for converting alternative backend
  127. arrays to numpy arrays within test code, to make it easier for use
  128. of the alternative backend to be isolated only to the function being
  129. tested. `_xp_copy_to_numpy` should NEVER be used except in test code
  130. for the specific purpose mentioned above. In production code, attempts
  131. to copy device arrays to NumPy arrays should fail, or else functions
  132. may appear to be working on the GPU when they actually aren't.
  133. Parameters
  134. ----------
  135. x : array
  136. Returns
  137. -------
  138. ndarray
  139. """
  140. xp = array_namespace(x)
  141. if is_numpy(xp):
  142. return x.copy()
  143. if is_cupy(xp):
  144. return x.get()
  145. if is_torch(xp):
  146. return x.cpu().numpy()
  147. if is_array_api_strict(xp):
  148. # array api strict supports multiple devices, so need to
  149. # ensure x is on the cpu before copying to NumPy.
  150. return np.asarray(
  151. xp.asarray(x, device=xp.Device("CPU_DEVICE")), copy=True
  152. )
  153. # Fall back to np.asarray. This works for dask.array. It
  154. # currently works for jax.numpy, but hopefully JAX will make
  155. # the transfer guard workable enough for use in scipy tests, in
  156. # which case, JAX will have to be handled explicitly.
  157. # If new backends are added, they may require explicit handling as
  158. # well.
  159. return np.asarray(x, copy=True)
  160. _default_xp_ctxvar: ContextVar[ModuleType] = ContextVar("_default_xp")
  161. @contextmanager
  162. def default_xp(xp: ModuleType) -> Generator[None, None, None]:
  163. """In all ``xp_assert_*`` and ``assert_*`` function calls executed within this
  164. context manager, test by default that the array namespace is
  165. the provided across all arrays, unless one explicitly passes the ``xp=``
  166. parameter or ``check_namespace=False``.
  167. Without this context manager, the default value for `xp` is the namespace
  168. for the desired array (the second parameter of the tests).
  169. """
  170. token = _default_xp_ctxvar.set(xp)
  171. try:
  172. yield
  173. finally:
  174. _default_xp_ctxvar.reset(token)
  175. def eager_warns(warning_type, *, match=None, xp):
  176. """pytest.warns context manager if arrays of specified namespace are always eager.
  177. Otherwise, context manager that *ignores* specified warning.
  178. """
  179. import pytest
  180. from scipy._lib._util import ignore_warns
  181. if is_numpy(xp) or is_array_api_strict(xp) or is_cupy(xp):
  182. return pytest.warns(warning_type, match=match)
  183. return ignore_warns(warning_type, match='' if match is None else match)
  184. def _strict_check(actual, desired, xp, *,
  185. check_namespace=True, check_dtype=True, check_shape=True,
  186. check_0d=True):
  187. __tracebackhide__ = True # Hide traceback for py.test
  188. if xp is None:
  189. try:
  190. xp = _default_xp_ctxvar.get()
  191. except LookupError:
  192. xp = array_namespace(desired)
  193. if check_namespace:
  194. _assert_matching_namespace(actual, desired, xp)
  195. # only NumPy distinguishes between scalars and arrays; we do if check_0d=True.
  196. # do this first so we can then cast to array (and thus use the array API) below.
  197. if is_numpy(xp) and check_0d:
  198. _msg = ("Array-ness does not match:\n Actual: "
  199. f"{type(actual)}\n Desired: {type(desired)}")
  200. assert ((xp.isscalar(actual) and xp.isscalar(desired))
  201. or (not xp.isscalar(actual) and not xp.isscalar(desired))), _msg
  202. actual = xp.asarray(actual)
  203. desired = xp.asarray(desired)
  204. if check_dtype:
  205. _msg = f"dtypes do not match.\nActual: {actual.dtype}\nDesired: {desired.dtype}"
  206. assert actual.dtype == desired.dtype, _msg
  207. if check_shape:
  208. if is_dask(xp):
  209. actual.compute_chunk_sizes()
  210. desired.compute_chunk_sizes()
  211. _msg = f"Shapes do not match.\nActual: {actual.shape}\nDesired: {desired.shape}"
  212. assert actual.shape == desired.shape, _msg
  213. desired = xp.broadcast_to(desired, actual.shape)
  214. return actual, desired, xp
  215. def _assert_matching_namespace(actual, desired, xp):
  216. __tracebackhide__ = True # Hide traceback for py.test
  217. desired_arr_space = array_namespace(desired)
  218. _msg = ("Namespace of desired array does not match expectations "
  219. "set by the `default_xp` context manager or by the `xp`"
  220. "pytest fixture.\n"
  221. f"Desired array's space: {desired_arr_space.__name__}\n"
  222. f"Expected namespace: {xp.__name__}")
  223. assert desired_arr_space == xp, _msg
  224. actual_arr_space = array_namespace(actual)
  225. _msg = ("Namespace of actual and desired arrays do not match.\n"
  226. f"Actual: {actual_arr_space.__name__}\n"
  227. f"Desired: {xp.__name__}")
  228. assert actual_arr_space == xp, _msg
  229. def xp_assert_equal(actual, desired, *, check_namespace=True, check_dtype=True,
  230. check_shape=True, check_0d=True, err_msg='', xp=None):
  231. __tracebackhide__ = True # Hide traceback for py.test
  232. actual, desired, xp = _strict_check(
  233. actual, desired, xp, check_namespace=check_namespace,
  234. check_dtype=check_dtype, check_shape=check_shape,
  235. check_0d=check_0d
  236. )
  237. if is_cupy(xp):
  238. return xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)
  239. elif is_torch(xp):
  240. # PyTorch recommends using `rtol=0, atol=0` like this
  241. # to test for exact equality
  242. err_msg = None if err_msg == '' else err_msg
  243. return xp.testing.assert_close(actual, desired, rtol=0, atol=0, equal_nan=True,
  244. check_dtype=False, msg=err_msg)
  245. # JAX uses `np.testing`
  246. return np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
  247. def xp_assert_close(actual, desired, *, rtol=None, atol=0, check_namespace=True,
  248. check_dtype=True, check_shape=True, check_0d=True,
  249. err_msg='', xp=None):
  250. __tracebackhide__ = True # Hide traceback for py.test
  251. actual, desired, xp = _strict_check(
  252. actual, desired, xp,
  253. check_namespace=check_namespace, check_dtype=check_dtype,
  254. check_shape=check_shape, check_0d=check_0d
  255. )
  256. floating = xp.isdtype(actual.dtype, ('real floating', 'complex floating'))
  257. if rtol is None and floating:
  258. # multiplier of 4 is used as for `np.float64` this puts the default `rtol`
  259. # roughly half way between sqrt(eps) and the default for
  260. # `numpy.testing.assert_allclose`, 1e-7
  261. rtol = xp.finfo(actual.dtype).eps**0.5 * 4
  262. elif rtol is None:
  263. rtol = 1e-7
  264. if is_cupy(xp):
  265. return xp.testing.assert_allclose(actual, desired, rtol=rtol,
  266. atol=atol, err_msg=err_msg)
  267. elif is_torch(xp):
  268. err_msg = None if err_msg == '' else err_msg
  269. return xp.testing.assert_close(actual, desired, rtol=rtol, atol=atol,
  270. equal_nan=True, check_dtype=False, msg=err_msg)
  271. # JAX uses `np.testing`
  272. return np.testing.assert_allclose(actual, desired, rtol=rtol,
  273. atol=atol, err_msg=err_msg)
  274. def xp_assert_close_nulp(actual, desired, *, nulp=1, check_namespace=True,
  275. check_dtype=True, check_shape=True, check_0d=True,
  276. err_msg='', xp=None):
  277. __tracebackhide__ = True # Hide traceback for py.test
  278. actual, desired, xp = _strict_check(
  279. actual, desired, xp,
  280. check_namespace=check_namespace, check_dtype=check_dtype,
  281. check_shape=check_shape, check_0d=check_0d
  282. )
  283. actual, desired = map(_xp_copy_to_numpy, (actual, desired))
  284. return np.testing.assert_array_almost_equal_nulp(actual, desired, nulp=nulp)
  285. def xp_assert_less(actual, desired, *, check_namespace=True, check_dtype=True,
  286. check_shape=True, check_0d=True, err_msg='', verbose=True, xp=None):
  287. __tracebackhide__ = True # Hide traceback for py.test
  288. actual, desired, xp = _strict_check(
  289. actual, desired, xp, check_namespace=check_namespace,
  290. check_dtype=check_dtype, check_shape=check_shape,
  291. check_0d=check_0d
  292. )
  293. if is_cupy(xp):
  294. return xp.testing.assert_array_less(actual, desired,
  295. err_msg=err_msg, verbose=verbose)
  296. elif is_torch(xp):
  297. if actual.device.type != 'cpu':
  298. actual = actual.cpu()
  299. if desired.device.type != 'cpu':
  300. desired = desired.cpu()
  301. # JAX uses `np.testing`
  302. return np.testing.assert_array_less(actual, desired,
  303. err_msg=err_msg, verbose=verbose)
  304. def assert_array_almost_equal(actual, desired, decimal=6, *args, **kwds):
  305. """Backwards compatible replacement. In new code, use xp_assert_close instead.
  306. """
  307. rtol, atol = 0, 1.5*10**(-decimal)
  308. return xp_assert_close(actual, desired,
  309. atol=atol, rtol=rtol, check_dtype=False, check_shape=False,
  310. *args, **kwds)
  311. def assert_almost_equal(actual, desired, decimal=7, *args, **kwds):
  312. """Backwards compatible replacement. In new code, use xp_assert_close instead.
  313. """
  314. rtol, atol = 0, 1.5*10**(-decimal)
  315. return xp_assert_close(actual, desired,
  316. atol=atol, rtol=rtol, check_dtype=False, check_shape=False,
  317. *args, **kwds)
  318. def xp_unsupported_param_msg(param: Any) -> str:
  319. return f'Providing {param!r} is only supported for numpy arrays.'
  320. def is_complex(x: Array, xp: ModuleType) -> bool:
  321. return xp.isdtype(x.dtype, 'complex floating')
  322. def get_native_namespace_name(xp: ModuleType) -> str:
  323. """Return name for native namespace (without array_api_compat prefix)."""
  324. name = xp.__name__
  325. return name.removeprefix(f"{_compat_module_name()}.")
  326. def scipy_namespace_for(xp: ModuleType) -> ModuleType | None:
  327. """Return the `scipy`-like namespace of a non-NumPy backend
  328. That is, return the namespace corresponding with backend `xp` that contains
  329. `scipy` sub-namespaces like `linalg` and `special`. If no such namespace
  330. exists, return ``None``. Useful for dispatching.
  331. """
  332. if is_cupy(xp):
  333. import cupyx # type: ignore[import-not-found,import-untyped]
  334. return cupyx.scipy
  335. if is_jax(xp):
  336. import jax # type: ignore[import-not-found]
  337. return jax.scipy
  338. if is_torch(xp):
  339. return xp
  340. return None
  341. # maybe use `scipy.linalg` if/when array API support is added
  342. def xp_vector_norm(x: Array, /, *,
  343. axis: int | tuple[int] | None = None,
  344. keepdims: bool = False,
  345. ord: int | float = 2,
  346. xp: ModuleType | None = None) -> Array:
  347. xp = array_namespace(x) if xp is None else xp
  348. if SCIPY_ARRAY_API:
  349. # check for optional `linalg` extension
  350. if hasattr(xp, 'linalg'):
  351. return xp.linalg.vector_norm(x, axis=axis, keepdims=keepdims, ord=ord)
  352. else:
  353. if ord != 2:
  354. raise ValueError(
  355. "only the Euclidean norm (`ord=2`) is currently supported in "
  356. "`xp_vector_norm` for backends not implementing the `linalg` "
  357. "extension."
  358. )
  359. # return (x @ x)**0.5
  360. # or to get the right behavior with nd, complex arrays
  361. return xp.sum(xp.conj(x) * x, axis=axis, keepdims=keepdims)**0.5
  362. else:
  363. # to maintain backwards compatibility
  364. return np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)
  365. def xp_ravel(x: Array, /, *, xp: ModuleType | None = None) -> Array:
  366. # Equivalent of np.ravel written in terms of array API
  367. # Even though it's one line, it comes up so often that it's worth having
  368. # this function for readability
  369. xp = array_namespace(x) if xp is None else xp
  370. return xp.reshape(x, (-1,))
  371. def xp_swapaxes(a, axis1, axis2, xp=None):
  372. # Equivalent of np.swapaxes written in terms of array API
  373. xp = array_namespace(a) if xp is None else xp
  374. axes = list(range(a.ndim))
  375. axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
  376. a = xp.permute_dims(a, axes)
  377. return a
  378. # utility to find common dtype with option to force floating
  379. def xp_result_type(*args, force_floating=False, xp):
  380. """
  381. Returns the dtype that results from applying type promotion rules
  382. (see Array API Standard Type Promotion Rules) to the arguments. Augments
  383. standard `result_type` in a few ways:
  384. - There is a `force_floating` argument that ensures that the result type
  385. is floating point, even when all args are integer.
  386. - When a TypeError is raised (e.g. due to an unsupported promotion)
  387. and `force_floating=True`, we define a custom rule: use the result type
  388. of the default float and any other floats passed. See
  389. https://github.com/scipy/scipy/pull/22695/files#r1997905891
  390. for rationale.
  391. - This function accepts array-like iterables, which are immediately converted
  392. to the namespace's arrays before result type calculation. Consequently, the
  393. result dtype may be different when an argument is `1.` vs `[1.]`.
  394. Typically, this function will be called shortly after `array_namespace`
  395. on a subset of the arguments passed to `array_namespace`.
  396. """
  397. # prevent double conversion of iterable to array
  398. # avoid `np.iterable` for torch arrays due to pytorch/pytorch#143334
  399. # don't use `array_api_compat.is_array_api_obj` as it returns True for NumPy scalars
  400. args = [(_asarray(arg, subok=True, xp=xp) if is_torch_array(arg) or np.iterable(arg)
  401. else arg) for arg in args]
  402. args_not_none = [arg for arg in args if arg is not None]
  403. if force_floating:
  404. args_not_none.append(1.0)
  405. if is_numpy(xp) and xp.__version__ < '2.0':
  406. # Follow NEP 50 promotion rules anyway
  407. args_not_none = [arg.dtype if getattr(arg, 'size', 0) == 1 else arg
  408. for arg in args_not_none]
  409. return xp.result_type(*args_not_none)
  410. try: # follow library's preferred promotion rules
  411. return xp.result_type(*args_not_none)
  412. except TypeError: # mixed type promotion isn't defined
  413. if not force_floating:
  414. raise
  415. # use `result_type` of default floating point type and any floats present
  416. # This can be revisited, but right now, the only backends that get here
  417. # are array-api-strict (which is not for production use) and PyTorch
  418. # (due to data-apis/array-api-compat#279).
  419. float_args = []
  420. for arg in args_not_none:
  421. arg_array = xp.asarray(arg) if np.isscalar(arg) else arg
  422. dtype = getattr(arg_array, 'dtype', arg)
  423. if xp.isdtype(dtype, ('real floating', 'complex floating')):
  424. float_args.append(arg)
  425. return xp.result_type(*float_args, xp_default_dtype(xp))
def xp_promote(*args, broadcast=False, force_floating=False, xp):
    """
    Promotes elements of *args to result dtype, ignoring `None`s.
    Includes options for forcing promotion to floating point and
    broadcasting the arrays, again ignoring `None`s.
    Type promotion rules follow `xp_result_type` instead of `xp.result_type`.

    Typically, this function will be called shortly after `array_namespace`
    on a subset of the arguments passed to `array_namespace`.

    This function accepts array-like iterables, which are immediately converted
    to the namespace's arrays before result type calculation. Consequently, the
    result dtype may be different when an argument is `1.` vs `[1.]`.

    Returns
    -------
    With no arguments, the empty `args` tuple is returned as is. With exactly
    one argument, that single promoted (and possibly broadcast) array, or
    None, is returned. Otherwise a tuple with one entry per argument is
    returned, with `None`s kept in their original positions.

    Raises
    ------
    ValueError
        If ``broadcast=True`` and the array shapes are incompatible for
        broadcasting.

    See Also
    --------
    xp_result_type
    """
    if not args:
        return args

    # prevent double conversion of iterable to array
    # avoid `np.iterable` for torch arrays due to pytorch/pytorch#143334
    # don't use `array_api_compat.is_array_api_obj` as it returns True for NumPy scalars
    args = [(_asarray(arg, subok=True, xp=xp) if is_torch_array(arg) or np.iterable(arg)
             else arg) for arg in args]

    # Common dtype of the arguments, per `xp_result_type`.
    dtype = xp_result_type(*args, force_floating=force_floating, xp=xp)

    # Convert every non-None argument to an `xp` array of that dtype.
    args = [(_asarray(arg, dtype=dtype, subok=True, xp=xp) if arg is not None else arg)
            for arg in args]

    if not broadcast:
        return args[0] if len(args)==1 else tuple(args)

    args_not_none = [arg for arg in args if arg is not None]

    # determine result shape
    # The target shape is computed from the shape tuples with NumPy's
    # `broadcast_shapes`; that call is skipped when all shapes already agree.
    shapes = {arg.shape for arg in args_not_none}
    try:
        shape = (np.broadcast_shapes(*shapes) if len(shapes) != 1
                 else args_not_none[0].shape)
    except ValueError as e:
        message = "Array shapes are incompatible for broadcasting."
        raise ValueError(message) from e

    out = []
    for arg in args:
        if arg is None:
            out.append(arg)
            continue

        # broadcast only if needed
        # Even if two arguments need broadcasting, this is faster than
        # `broadcast_arrays`, especially since we've already determined `shape`
        if arg.shape != shape:
            # `subok=True` keeps NumPy array subclasses when broadcasting.
            kwargs = {'subok': True} if is_numpy(xp) else {}
            arg = xp.broadcast_to(arg, shape, **kwargs)

        # This is much faster than xp.astype(arg, dtype, copy=False)
        if arg.dtype != dtype:
            arg = xp.astype(arg, dtype)
        out.append(arg)

    return out[0] if len(out)==1 else tuple(out)
  478. def xp_float_to_complex(arr: Array, xp: ModuleType | None = None) -> Array:
  479. xp = array_namespace(arr) if xp is None else xp
  480. arr_dtype = arr.dtype
  481. # The standard float dtypes are float32 and float64.
  482. # Convert float32 to complex64,
  483. # and float64 (and non-standard real dtypes) to complex128
  484. if xp.isdtype(arr_dtype, xp.float32):
  485. arr = xp.astype(arr, xp.complex64)
  486. elif xp.isdtype(arr_dtype, 'real floating'):
  487. arr = xp.astype(arr, xp.complex128)
  488. return arr
  489. def xp_default_dtype(xp):
  490. """Query the namespace-dependent default floating-point dtype.
  491. """
  492. if is_torch(xp):
  493. # historically, we allow pytorch to keep its default of float32
  494. return xp.get_default_dtype()
  495. else:
  496. # we default to float64
  497. return xp.float64
  498. ### MArray Helpers ###
  499. def xp_result_device(*args):
  500. """Return the device of an array in `args`, for the purpose of
  501. input-output device propagation.
  502. If there are multiple devices, return an arbitrary one.
  503. If there are no arrays, return None (this typically happens only on NumPy).
  504. """
  505. for arg in args:
  506. # Do not do a duck-type test for the .device attribute, as many backends today
  507. # don't have it yet. See workarouunds in array_api_compat.device().
  508. if is_array_api_obj(arg):
  509. return xp_device(arg)
  510. return None
  511. # np.r_ replacement
  512. def concat_1d(xp: ModuleType | None, *arrays: Iterable[ArrayLike]) -> Array:
  513. """A replacement for `np.r_` as `xp.concat` does not accept python scalars
  514. or 0-D arrays.
  515. """
  516. arys = [xpx.atleast_nd(xp.asarray(a), ndim=1, xp=xp) for a in arrays]
  517. return xp.concat(arys)
  518. def is_marray(xp):
  519. """Returns True if `xp` is an MArray namespace; False otherwise."""
  520. return "marray" in xp.__name__
  521. def _length_nonmasked(x, axis, keepdims=False, xp=None):
  522. xp = array_namespace(x) if xp is None else xp
  523. if is_marray(xp):
  524. if np.iterable(axis):
  525. message = '`axis` must be an integer or None for use with `MArray`.'
  526. raise NotImplementedError(message)
  527. return xp.astype(xp.count(x, axis=axis, keepdims=keepdims), x.dtype)
  528. return (xp_size(x) if axis is None else
  529. # compact way to deal with axis tuples or ints
  530. int(np.prod(np.asarray(x.shape)[np.asarray(axis)])))
  531. def _share_masks(*args, xp):
  532. if is_marray(xp):
  533. mask = functools.reduce(operator.or_, (arg.mask for arg in args))
  534. args = [xp.asarray(arg.data, mask=mask) for arg in args]
  535. return args[0] if len(args) == 1 else args
  536. ### End MArray Helpers ###
  537. @dataclasses.dataclass(repr=False)
  538. class _XPSphinxCapability:
  539. cpu: bool | None # None if not applicable
  540. gpu: bool | None
  541. warnings: list[str] = dataclasses.field(default_factory=list)
  542. def _render(self, value):
  543. if value is None:
  544. return "n/a"
  545. if not value:
  546. return "⛔"
  547. if self.warnings:
  548. res = "⚠️ " + '; '.join(self.warnings)
  549. assert len(res) <= 20, "Warnings too long"
  550. return res
  551. return "✅"
  552. def __str__(self):
  553. cpu = self._render(self.cpu)
  554. gpu = self._render(self.gpu)
  555. return f"{cpu:20} {gpu:20}"
  556. def _make_sphinx_capabilities(
  557. # lists of tuples [(module name, reason), ...]
  558. skip_backends=(), xfail_backends=(),
  559. # @pytest.mark.skip/xfail_xp_backends kwargs
  560. cpu_only=False, np_only=False, out_of_scope=False, exceptions=(),
  561. # xpx.lazy_xp_backends kwargs
  562. allow_dask_compute=False, jax_jit=True,
  563. # list of tuples [(module name, reason), ...]
  564. warnings = (),
  565. # unused in documentation
  566. reason=None,
  567. ):
  568. if out_of_scope:
  569. return {"out_of_scope": True}
  570. exceptions = set(exceptions)
  571. # Default capabilities
  572. capabilities = {
  573. "numpy": _XPSphinxCapability(cpu=True, gpu=None),
  574. "array_api_strict": _XPSphinxCapability(cpu=True, gpu=None),
  575. "cupy": _XPSphinxCapability(cpu=None, gpu=True),
  576. "torch": _XPSphinxCapability(cpu=True, gpu=True),
  577. "jax.numpy": _XPSphinxCapability(cpu=True, gpu=True,
  578. warnings=[] if jax_jit else ["no JIT"]),
  579. # Note: Dask+CuPy is currently untested and unsupported
  580. "dask.array": _XPSphinxCapability(cpu=True, gpu=None,
  581. warnings=["computes graph"] if allow_dask_compute else []),
  582. }
  583. # documentation doesn't display the reason
  584. for module, _ in list(skip_backends) + list(xfail_backends):
  585. backend = capabilities[module]
  586. if backend.cpu is not None:
  587. backend.cpu = False
  588. if backend.gpu is not None:
  589. backend.gpu = False
  590. for module, backend in capabilities.items():
  591. if np_only and module not in exceptions | {"numpy"}:
  592. if backend.cpu is not None:
  593. backend.cpu = False
  594. if backend.gpu is not None:
  595. backend.gpu = False
  596. elif cpu_only and module not in exceptions and backend.gpu is not None:
  597. backend.gpu = False
  598. for module, warning in warnings:
  599. backend = capabilities[module]
  600. backend.warnings.append(warning)
  601. return capabilities
  602. def _make_capabilities_note(fun_name, capabilities, extra_note=None):
  603. if "out_of_scope" in capabilities:
  604. # It will be better to link to a section of the dev-arrayapi docs
  605. # that explains what is and isn't in-scope, but such a section
  606. # doesn't exist yet. Using :ref:`dev-arrayapi` as a placeholder.
  607. note = f"""
  608. **Array API Standard Support**
  609. `{fun_name}` is not in-scope for support of Python Array API Standard compatible
  610. backends other than NumPy.
  611. See :ref:`dev-arrayapi` for more information.
  612. """
  613. return textwrap.dedent(note)
  614. # Note: deliberately not documenting array-api-strict
  615. note = f"""
  616. **Array API Standard Support**
  617. `{fun_name}` has experimental support for Python Array API Standard compatible
  618. backends in addition to NumPy. Please consider testing these features
  619. by setting an environment variable ``SCIPY_ARRAY_API=1`` and providing
  620. CuPy, PyTorch, JAX, or Dask arrays as array arguments. The following
  621. combinations of backend and device (or other capability) are supported.
  622. ==================== ==================== ====================
  623. Library CPU GPU
  624. ==================== ==================== ====================
  625. NumPy {capabilities['numpy'] }
  626. CuPy {capabilities['cupy'] }
  627. PyTorch {capabilities['torch'] }
  628. JAX {capabilities['jax.numpy'] }
  629. Dask {capabilities['dask.array'] }
  630. ==================== ==================== ====================
  631. """ + (extra_note or "") + " See :ref:`dev-arrayapi` for more information."
  632. return textwrap.dedent(note)
  633. def xp_capabilities(
  634. *,
  635. # Alternative capabilities table.
  636. # Used only for testing this decorator.
  637. capabilities_table=None,
  638. # Generate pytest.mark.skip/xfail_xp_backends.
  639. # See documentation in conftest.py.
  640. # lists of tuples [(module name, reason), ...]
  641. skip_backends=(), xfail_backends=(),
  642. cpu_only=False, np_only=False, reason=None,
  643. out_of_scope=False, exceptions=(),
  644. # lists of tuples [(module name, reason), ...]
  645. warnings=(),
  646. # xpx.testing.lazy_xp_function kwargs.
  647. # Refer to array-api-extra documentation.
  648. allow_dask_compute=False, jax_jit=True,
  649. # Extra note to inject into the docstring
  650. extra_note=None,
  651. ):
  652. """Decorator for a function that states its support among various
  653. Array API compatible backends.
  654. This decorator has two effects:
  655. 1. It allows tagging tests with ``@make_xp_test_case`` or
  656. ``make_xp_pytest_param`` (see below) to automatically generate
  657. SKIP/XFAIL markers and perform additional backend-specific
  658. testing, such as extra validation for Dask and JAX;
  659. 2. It automatically adds a note to the function's docstring, containing
  660. a table matching what has been tested.
  661. See Also
  662. --------
  663. make_xp_test_case
  664. make_xp_pytest_param
  665. array_api_extra.testing.lazy_xp_function
  666. """
  667. capabilities_table = (xp_capabilities_table if capabilities_table is None
  668. else capabilities_table)
  669. if out_of_scope:
  670. np_only = True
  671. capabilities = dict(
  672. skip_backends=skip_backends,
  673. xfail_backends=xfail_backends,
  674. cpu_only=cpu_only,
  675. np_only=np_only,
  676. out_of_scope=out_of_scope,
  677. reason=reason,
  678. exceptions=exceptions,
  679. allow_dask_compute=allow_dask_compute,
  680. jax_jit=jax_jit,
  681. warnings=warnings,
  682. )
  683. sphinx_capabilities = _make_sphinx_capabilities(**capabilities)
  684. def decorator(f):
  685. # Don't use a wrapper, as in some cases @xp_capabilities is
  686. # applied to a ufunc
  687. capabilities_table[f] = capabilities
  688. note = _make_capabilities_note(f.__name__, sphinx_capabilities, extra_note)
  689. doc = FunctionDoc(f)
  690. doc['Notes'].append(note)
  691. doc = str(doc).split("\n", 1)[1].lstrip(" \n") # remove signature
  692. try:
  693. f.__doc__ = doc
  694. except AttributeError:
  695. # Can't update __doc__ on ufuncs if SciPy
  696. # was compiled against NumPy < 2.2.
  697. pass
  698. return f
  699. return decorator
  700. def make_xp_test_case(*funcs, capabilities_table=None):
  701. capabilities_table = (xp_capabilities_table if capabilities_table is None
  702. else capabilities_table)
  703. """Generate pytest decorator for a test function that tests functionality
  704. of one or more Array API compatible functions.
  705. Read the parameters of the ``@xp_capabilities`` decorator applied to the
  706. listed functions and:
  707. - Generate the ``@pytest.mark.skip_xp_backends`` and
  708. ``@pytest.mark.xfail_xp_backends`` decorators
  709. for the decorated test function
  710. - Tag the function with `xpx.testing.lazy_xp_function`
  711. Example::
  712. @make_xp_test_case(f1)
  713. def test_f1(xp):
  714. ...
  715. @make_xp_test_case(f2)
  716. def test_f2(xp):
  717. ...
  718. @make_xp_test_case(f1, f2)
  719. def test_f1_and_f2(xp):
  720. ...
  721. The above is equivalent to::
  722. @pytest.mark.skip_xp_backends(...)
  723. @pytest.mark.skip_xp_backends(...)
  724. @pytest.mark.xfail_xp_backends(...)
  725. @pytest.mark.xfail_xp_backends(...)
  726. def test_f1(xp):
  727. ...
  728. etc., where the arguments of ``skip_xp_backends`` and ``xfail_xp_backends`` are
  729. determined by the ``@xp_capabilities`` decorator applied to the functions.
  730. See Also
  731. --------
  732. xp_capabilities
  733. make_xp_pytest_marks
  734. make_xp_pytest_param
  735. array_api_extra.testing.lazy_xp_function
  736. """
  737. marks = make_xp_pytest_marks(*funcs, capabilities_table=capabilities_table)
  738. return lambda func: functools.reduce(lambda f, g: g(f), marks, func)
  739. def make_xp_pytest_param(func, *args, capabilities_table=None):
  740. """Variant of ``make_xp_test_case`` that returns a pytest.param for a function,
  741. with all necessary skip_xp_backends and xfail_xp_backends marks applied::
  742. @pytest.mark.parametrize(
  743. "func", [make_xp_pytest_param(f1), make_xp_pytest_param(f2)]
  744. )
  745. def test(func, xp):
  746. ...
  747. The above is equivalent to::
  748. @pytest.mark.parametrize(
  749. "func", [
  750. pytest.param(f1, marks=[
  751. pytest.mark.skip_xp_backends(...),
  752. pytest.mark.xfail_xp_backends(...), ...]),
  753. pytest.param(f2, marks=[
  754. pytest.mark.skip_xp_backends(...),
  755. pytest.mark.xfail_xp_backends(...), ...]),
  756. )
  757. def test(func, xp):
  758. ...
  759. Parameters
  760. ----------
  761. func : Callable
  762. Function to be tested. It must be decorated with ``@xp_capabilities``.
  763. *args : Any, optional
  764. Extra pytest parameters for the use case, e.g.::
  765. @pytest.mark.parametrize("func,verb", [
  766. make_xp_pytest_param(f1, "hello"),
  767. make_xp_pytest_param(f2, "world")])
  768. def test(func, verb, xp):
  769. # iterates on (func=f1, verb="hello")
  770. # and (func=f2, verb="world")
  771. See Also
  772. --------
  773. xp_capabilities
  774. make_xp_test_case
  775. make_xp_pytest_marks
  776. array_api_extra.testing.lazy_xp_function
  777. """
  778. import pytest
  779. marks = make_xp_pytest_marks(func, capabilities_table=capabilities_table)
  780. return pytest.param(func, *args, marks=marks, id=func.__name__)
  781. def make_xp_pytest_marks(*funcs, capabilities_table=None):
  782. """Variant of ``make_xp_test_case`` that returns a list of pytest marks,
  783. which can be used with the module-level `pytestmark = ...` variable::
  784. pytestmark = make_xp_pytest_marks(f1, f2)
  785. def test(xp):
  786. ...
  787. In this example, the whole test module is dedicated to testing `f1` or `f2`,
  788. and the two functions have the same capabilities, so it's unnecessary to
  789. cherry-pick which test tests which function.
  790. The above is equivalent to::
  791. pytestmark = [
  792. pytest.mark.skip_xp_backends(...),
  793. pytest.mark.xfail_xp_backends(...), ...]),
  794. ]
  795. def test(xp):
  796. ...
  797. See Also
  798. --------
  799. xp_capabilities
  800. make_xp_test_case
  801. make_xp_pytest_param
  802. array_api_extra.testing.lazy_xp_function
  803. """
  804. capabilities_table = (xp_capabilities_table if capabilities_table is None
  805. else capabilities_table)
  806. import pytest
  807. marks = []
  808. for func in funcs:
  809. capabilities = capabilities_table[func]
  810. exceptions = capabilities['exceptions']
  811. reason = capabilities['reason']
  812. if capabilities['cpu_only']:
  813. marks.append(pytest.mark.skip_xp_backends(
  814. cpu_only=True, exceptions=exceptions, reason=reason))
  815. if capabilities['np_only']:
  816. marks.append(pytest.mark.skip_xp_backends(
  817. np_only=True, exceptions=exceptions, reason=reason))
  818. for mod_name, reason in capabilities['skip_backends']:
  819. marks.append(pytest.mark.skip_xp_backends(mod_name, reason=reason))
  820. for mod_name, reason in capabilities['xfail_backends']:
  821. marks.append(pytest.mark.xfail_xp_backends(mod_name, reason=reason))
  822. lazy_kwargs = {k: capabilities[k]
  823. for k in ('allow_dask_compute', 'jax_jit')}
  824. lazy_xp_function(func, **lazy_kwargs)
  825. return marks
  826. # Is it OK to have a dictionary that is mutated (once upon import) in many places?
  827. xp_capabilities_table = {} # type: ignore[var-annotated]
  828. def xp_device_type(a: Array) -> Literal["cpu", "cuda", None]:
  829. if is_numpy_array(a):
  830. return "cpu"
  831. if is_cupy_array(a):
  832. return "cuda"
  833. if is_torch_array(a):
  834. # TODO this can return other backends e.g. tpu but they're unsupported in scipy
  835. return a.device.type
  836. if is_jax_array(a):
  837. # TODO this can return other backends e.g. tpu but they're unsupported in scipy
  838. return "cuda" if (p := a.device.platform) == "gpu" else p
  839. if is_dask_array(a):
  840. return xp_device_type(a._meta)
  841. # array-api-strict is a stand-in for unknown libraries; don't special-case it
  842. return None