_util.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251
  1. import re
  2. from contextlib import contextmanager
  3. import functools
  4. import operator
  5. import warnings
  6. import numbers
  7. from collections import namedtuple
  8. import inspect
  9. import math
  10. import os
  11. import sys
  12. import textwrap
  13. from types import ModuleType
  14. from typing import Literal, TypeAlias, TypeVar
  15. import numpy as np
  16. from scipy._lib._array_api import (Array, array_namespace, is_lazy_array, is_numpy,
  17. is_marray, xp_size, xp_result_device, xp_result_type)
  18. from scipy._lib._docscrape import FunctionDoc, Parameter
  19. from scipy._lib._sparse import issparse
  20. from numpy.exceptions import AxisError
# Resolve the NumPy scalar types corresponding to C `long` / `unsigned long`.
# On NumPy >= 2.0 they are taken from `np.long` / `np.ulong` when those
# attributes exist; otherwise (older NumPy, or a 2.0 build lacking them) the
# code falls back to `np.int_` / `np.uint`.
np_long: type
np_ulong: type

if np.lib.NumpyVersion(np.__version__) >= "2.0.0.dev0":
    try:
        with warnings.catch_warnings():
            # Silence a FutureWarning ("In the future `np.long` will be
            # defined as ...") that some NumPy versions emit when `np.long`
            # is accessed, so importing this module stays quiet.
            warnings.filterwarnings(
                "ignore",
                r".*In the future `np\.long` will be defined as.*",
                FutureWarning,
            )
            np_long = np.long  # type: ignore[attr-defined]
            np_ulong = np.ulong  # type: ignore[attr-defined]
    except AttributeError:
        np_long = np.int_
        np_ulong = np.uint
else:
    np_long = np.int_
    np_ulong = np.uint

# Annotation aliases for scalar arguments accepting Python or NumPy numbers.
IntNumber = int | np.integer
DecimalNumber = float | np.floating | np.integer

# Value to pass as `copy=` when a copy should be made only if needed:
# `None` on NumPy >= 2.0 and `False` on NumPy < 1.28; for versions in
# between (2.0 pre-releases) it is probed at import time below.
copy_if_needed: bool | None

if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
    copy_if_needed = None
elif np.lib.NumpyVersion(np.__version__) < "1.28.0":
    copy_if_needed = False
else:
    # 2.0.0 dev versions, handle cases where copy may or may not exist:
    # if `__array__` accepts `copy=None`, use the NumPy 2 spelling.
    try:
        np.array([1]).__array__(copy=None)  # type: ignore[call-overload]
        copy_if_needed = None
    except TypeError:
        copy_if_needed = False
  53. # Wrapped function for inspect.signature for compatibility with Python 3.14+
  54. # See gh-23913
  55. #
  56. # PEP 649/749 allows for underfined annotations at runtime, and added the
  57. # `annotation_format` parameter to handle these cases.
  58. # `annotationlib.Format.FORWARDREF` is the closest to previous behavior,
  59. # returning ForwardRef objects fornew undefined annotations cases.
  60. #
  61. # Consider dropping this wrapper when support for Python 3.13 is dropped.
  62. if sys.version_info >= (3, 14):
  63. import annotationlib
  64. def wrapped_inspect_signature(callable):
  65. """Get a signature object for the passed callable."""
  66. return inspect.signature(callable,
  67. annotation_format=annotationlib.Format.FORWARDREF)
  68. else:
  69. wrapped_inspect_signature = inspect.signature
  70. _RNG: TypeAlias = np.random.Generator | np.random.RandomState
  71. SeedType: TypeAlias = IntNumber | _RNG | None
  72. GeneratorType = TypeVar("GeneratorType", bound=_RNG)
  73. def _lazyselect(condlist, choicelist, arrays, default=0):
  74. """
  75. Mimic `np.select(condlist, choicelist)`.
  76. Notice, it assumes that all `arrays` are of the same shape or can be
  77. broadcasted together.
  78. All functions in `choicelist` must accept array arguments in the order
  79. given in `arrays` and must return an array of the same shape as broadcasted
  80. `arrays`.
  81. Examples
  82. --------
  83. >>> import numpy as np
  84. >>> x = np.arange(6)
  85. >>> np.select([x <3, x > 3], [x**2, x**3], default=0)
  86. array([ 0, 1, 4, 0, 64, 125])
  87. >>> _lazyselect([x < 3, x > 3], [lambda x: x**2, lambda x: x**3], (x,))
  88. array([ 0., 1., 4., 0., 64., 125.])
  89. >>> a = -np.ones_like(x)
  90. >>> _lazyselect([x < 3, x > 3],
  91. ... [lambda x, a: x**2, lambda x, a: a * x**3],
  92. ... (x, a), default=np.nan)
  93. array([ 0., 1., 4., nan, -64., -125.])
  94. """
  95. arrays = np.broadcast_arrays(*arrays)
  96. tcode = np.mintypecode([a.dtype.char for a in arrays])
  97. out = np.full(np.shape(arrays[0]), fill_value=default, dtype=tcode)
  98. for func, cond in zip(choicelist, condlist):
  99. if np.all(cond is False):
  100. continue
  101. cond, _ = np.broadcast_arrays(cond, arrays[0])
  102. temp = tuple(np.extract(cond, arr) for arr in arrays)
  103. np.place(out, cond, func(*temp))
  104. return out
  105. def _aligned_zeros(shape, dtype=float, order="C", align=None):
  106. """Allocate a new ndarray with aligned memory.
  107. Primary use case for this currently is working around a f2py issue
  108. in NumPy 1.9.1, where dtype.alignment is such that np.zeros() does
  109. not necessarily create arrays aligned up to it.
  110. """
  111. dtype = np.dtype(dtype)
  112. if align is None:
  113. align = dtype.alignment
  114. if not hasattr(shape, '__len__'):
  115. shape = (shape,)
  116. size = functools.reduce(operator.mul, shape) * dtype.itemsize
  117. buf = np.empty(size + align + 1, np.uint8)
  118. offset = buf.__array_interface__['data'][0] % align
  119. if offset != 0:
  120. offset = align - offset
  121. # Note: slices producing 0-size arrays do not necessarily change
  122. # data pointer --- so we use and allocate size+1
  123. buf = buf[offset:offset+size+1][:-1]
  124. data = np.ndarray(shape, dtype, buf, order=order)
  125. data.fill(0)
  126. return data
  127. def _prune_array(array):
  128. """Return an array equivalent to the input array. If the input
  129. array is a view of a much larger array, copy its contents to a
  130. newly allocated array. Otherwise, return the input unchanged.
  131. """
  132. if array.base is not None and array.size < array.base.size // 2:
  133. return array.copy()
  134. return array
  135. def float_factorial(n: int) -> float:
  136. """Compute the factorial and return as a float
  137. Returns infinity when result is too large for a double
  138. """
  139. return float(math.factorial(n)) if n < 171 else np.inf
# Shared documentation for the `rng` parameter. `_transition_to_rng` (defined
# below) substitutes this text into a decorated function's docstring when
# `replace_doc=True`. The `{old_name}` markers are replaced with the legacy
# parameter name (e.g. `random_state`) via `str.replace`, not `str.format`.
_rng_desc = (
    r"""If `rng` is passed by keyword, types other than `numpy.random.Generator` are
passed to `numpy.random.default_rng` to instantiate a ``Generator``.
If `rng` is already a ``Generator`` instance, then the provided instance is
used. Specify `rng` for repeatable function behavior.
If this argument is passed by position or `{old_name}` is passed by keyword,
legacy behavior for the argument `{old_name}` applies:
- If `{old_name}` is None (or `numpy.random`), the `numpy.random.RandomState`
singleton is used.
- If `{old_name}` is an int, a new ``RandomState`` instance is used,
seeded with `{old_name}`.
- If `{old_name}` is already a ``Generator`` or ``RandomState`` instance then
that instance is used.
.. versionchanged:: 1.15.0
As part of the `SPEC-007 <https://scientific-python.org/specs/spec-0007/>`_
transition from use of `numpy.random.RandomState` to
`numpy.random.Generator`, this keyword was changed from `{old_name}` to `rng`.
For an interim period, both keywords will continue to work, although only one
may be specified at a time. After the interim period, function calls using the
`{old_name}` keyword will emit warnings. The behavior of both `{old_name}` and
`rng` are outlined above, but only the `rng` keyword should be used in new code.
"""
)
  163. # SPEC 7
  164. def _transition_to_rng(old_name, *, position_num=None, end_version=None,
  165. replace_doc=True):
  166. """Example decorator to transition from old PRNG usage to new `rng` behavior
  167. Suppose the decorator is applied to a function that used to accept parameter
  168. `old_name='random_state'` either by keyword or as a positional argument at
  169. `position_num=1`. At the time of application, the name of the argument in the
  170. function signature is manually changed to the new name, `rng`. If positional
  171. use was allowed before, this is not changed.*
  172. - If the function is called with both `random_state` and `rng`, the decorator
  173. raises an error.
  174. - If `random_state` is provided as a keyword argument, the decorator passes
  175. `random_state` to the function's `rng` argument as a keyword. If `end_version`
  176. is specified, the decorator will emit a `DeprecationWarning` about the
  177. deprecation of keyword `random_state`.
  178. - If `random_state` is provided as a positional argument, the decorator passes
  179. `random_state` to the function's `rng` argument by position. If `end_version`
  180. is specified, the decorator will emit a `FutureWarning` about the changing
  181. interpretation of the argument.
  182. - If `rng` is provided as a keyword argument, the decorator validates `rng` using
  183. `numpy.random.default_rng` before passing it to the function.
  184. - If `end_version` is specified and neither `random_state` nor `rng` is provided
  185. by the user, the decorator checks whether `np.random.seed` has been used to set
  186. the global seed. If so, it emits a `FutureWarning`, noting that usage of
  187. `numpy.random.seed` will eventually have no effect. Either way, the decorator
  188. calls the function without explicitly passing the `rng` argument.
  189. If `end_version` is specified, a user must pass `rng` as a keyword to avoid
  190. warnings.
  191. After the deprecation period, the decorator can be removed, and the function
  192. can simply validate the `rng` argument by calling `np.random.default_rng(rng)`.
  193. * A `FutureWarning` is emitted when the PRNG argument is used by
  194. position. It indicates that the "Hinsen principle" (same
  195. code yielding different results in two versions of the software)
  196. will be violated, unless positional use is deprecated. Specifically:
  197. - If `None` is passed by position and `np.random.seed` has been used,
  198. the function will change from being seeded to being unseeded.
  199. - If an integer is passed by position, the random stream will change.
  200. - If `np.random` or an instance of `RandomState` is passed by position,
  201. an error will be raised.
  202. We suggest that projects consider deprecating positional use of
  203. `random_state`/`rng` (i.e., change their function signatures to
  204. ``def my_func(..., *, rng=None)``); that might not make sense
  205. for all projects, so this SPEC does not make that
  206. recommendation, neither does this decorator enforce it.
  207. Parameters
  208. ----------
  209. old_name : str
  210. The old name of the PRNG argument (e.g. `seed` or `random_state`).
  211. position_num : int, optional
  212. The (0-indexed) position of the old PRNG argument (if accepted by position).
  213. Maintainers are welcome to eliminate this argument and use, for example,
  214. `inspect`, if preferred.
  215. end_version : str, optional
  216. The full version number of the library when the behavior described in
  217. `DeprecationWarning`s and `FutureWarning`s will take effect. If left
  218. unspecified, no warnings will be emitted by the decorator.
  219. replace_doc : bool, default: True
  220. Whether the decorator should replace the documentation for parameter `rng` with
  221. `_rng_desc` (defined above), which documents both new `rng` keyword behavior
  222. and typical legacy `random_state`/`seed` behavior. If True, manually replace
  223. the first paragraph of the function's old `random_state`/`seed` documentation
  224. with the desired *final* `rng` documentation; this way, no changes to
  225. documentation are needed when the decorator is removed. Documentation of `rng`
  226. after the first blank line is preserved. Use False if the function's old
  227. `random_state`/`seed` behavior does not match that described by `_rng_desc`.
  228. """
  229. NEW_NAME = "rng"
  230. cmn_msg = (
  231. "To silence this warning and ensure consistent behavior in SciPy "
  232. f"{end_version}, control the RNG using argument `{NEW_NAME}`. Arguments passed "
  233. f"to keyword `{NEW_NAME}` will be validated by `np.random.default_rng`, so the "
  234. "behavior corresponding with a given value may change compared to use of "
  235. f"`{old_name}`. For example, "
  236. "1) `None` will result in unpredictable random numbers, "
  237. "2) an integer will result in a different stream of random numbers, (with the "
  238. "same distribution), and "
  239. "3) `np.random` or `RandomState` instances will result in an error. "
  240. "See the documentation of `default_rng` for more information."
  241. )
  242. def decorator(fun):
  243. @functools.wraps(fun)
  244. def wrapper(*args, **kwargs):
  245. # Determine how PRNG was passed
  246. as_old_kwarg = old_name in kwargs
  247. as_new_kwarg = NEW_NAME in kwargs
  248. as_pos_arg = position_num is not None and len(args) >= position_num + 1
  249. emit_warning = end_version is not None
  250. # Can only specify PRNG one of the three ways
  251. if int(as_old_kwarg) + int(as_new_kwarg) + int(as_pos_arg) > 1:
  252. message = (
  253. f"{fun.__name__}() got multiple values for "
  254. f"argument now known as `{NEW_NAME}`. Specify one of "
  255. f"`{NEW_NAME}` or `{old_name}`."
  256. )
  257. raise TypeError(message)
  258. # Check whether global random state has been set
  259. global_seed_set = np.random.mtrand._rand._bit_generator._seed_seq is None
  260. if as_old_kwarg: # warn about deprecated use of old kwarg
  261. kwargs[NEW_NAME] = kwargs.pop(old_name)
  262. if emit_warning:
  263. message = (
  264. f"Use of keyword argument `{old_name}` is "
  265. f"deprecated and replaced by `{NEW_NAME}`. "
  266. f"Support for `{old_name}` will be removed "
  267. f"in SciPy {end_version}. "
  268. ) + cmn_msg
  269. warnings.warn(message, DeprecationWarning, stacklevel=2)
  270. elif as_pos_arg:
  271. # Warn about changing meaning of positional arg
  272. # Note that this decorator does not deprecate positional use of the
  273. # argument; it only warns that the behavior will change in the future.
  274. # Simultaneously transitioning to keyword-only use is another option.
  275. arg = args[position_num]
  276. # If the argument is None and the global seed wasn't set, or if the
  277. # argument is one of a few new classes, the user will not notice change
  278. # in behavior.
  279. ok_classes = (
  280. np.random.Generator,
  281. np.random.SeedSequence,
  282. np.random.BitGenerator,
  283. )
  284. if (arg is None and not global_seed_set) or isinstance(arg, ok_classes):
  285. pass
  286. elif emit_warning:
  287. message = (
  288. f"Positional use of `{NEW_NAME}` (formerly known as "
  289. f"`{old_name}`) is still allowed, but the behavior is "
  290. "changing: the argument will be normalized using "
  291. f"`np.random.default_rng` beginning in SciPy {end_version}, "
  292. "and the resulting `Generator` will be used to generate "
  293. "random numbers."
  294. ) + cmn_msg
  295. warnings.warn(message, FutureWarning, stacklevel=2)
  296. elif as_new_kwarg: # no warnings; this is the preferred use
  297. # After the removal of the decorator, normalization with
  298. # np.random.default_rng will be done inside the decorated function
  299. kwargs[NEW_NAME] = np.random.default_rng(kwargs[NEW_NAME])
  300. elif global_seed_set and emit_warning:
  301. # Emit FutureWarning if `np.random.seed` was used and no PRNG was passed
  302. message = (
  303. "The NumPy global RNG was seeded by calling "
  304. f"`np.random.seed`. Beginning in {end_version}, this "
  305. "function will no longer use the global RNG."
  306. ) + cmn_msg
  307. warnings.warn(message, FutureWarning, stacklevel=2)
  308. return fun(*args, **kwargs)
  309. # Add the old parameter name to the function signature
  310. wrapped_signature = inspect.signature(fun)
  311. wrapper.__signature__ = wrapped_signature.replace(parameters=[
  312. *wrapped_signature.parameters.values(),
  313. inspect.Parameter(old_name, inspect.Parameter.KEYWORD_ONLY, default=None),
  314. ])
  315. if replace_doc:
  316. doc = FunctionDoc(wrapper)
  317. parameter_names = [param.name for param in doc['Parameters']]
  318. if 'rng' in parameter_names:
  319. _type = "{None, int, `numpy.random.Generator`}, optional"
  320. _desc = _rng_desc.replace("{old_name}", old_name)
  321. old_doc = doc['Parameters'][parameter_names.index('rng')].desc
  322. old_doc_keep = old_doc[old_doc.index("") + 1:] if "" in old_doc else []
  323. new_doc = [_desc] + old_doc_keep
  324. _rng_parameter_doc = Parameter('rng', _type, new_doc)
  325. doc['Parameters'][parameter_names.index('rng')] = _rng_parameter_doc
  326. doc = str(doc).split("\n", 1)[1].lstrip(" \n") # remove signature
  327. wrapper.__doc__ = str(doc)
  328. return wrapper
  329. return decorator
  330. # copy-pasted from scikit-learn utils/validation.py
  331. def check_random_state(seed):
  332. """Turn `seed` into a `np.random.RandomState` instance.
  333. Parameters
  334. ----------
  335. seed : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
  336. If `seed` is None (or `np.random`), the `numpy.random.RandomState`
  337. singleton is used.
  338. If `seed` is an int, a new ``RandomState`` instance is used,
  339. seeded with `seed`.
  340. If `seed` is already a ``Generator`` or ``RandomState`` instance then
  341. that instance is used.
  342. Returns
  343. -------
  344. seed : {`numpy.random.Generator`, `numpy.random.RandomState`}
  345. Random number generator.
  346. """
  347. if seed is None or seed is np.random:
  348. return np.random.mtrand._rand
  349. if isinstance(seed, numbers.Integral | np.integer):
  350. return np.random.RandomState(seed)
  351. if isinstance(seed, np.random.RandomState | np.random.Generator):
  352. return seed
  353. raise ValueError(f"'{seed}' cannot be used to seed a numpy.random.RandomState"
  354. " instance")
  355. def _asarray_validated(a, check_finite=True,
  356. sparse_ok=False, objects_ok=False, mask_ok=False,
  357. as_inexact=False):
  358. """
  359. Helper function for SciPy argument validation.
  360. Many SciPy linear algebra functions do support arbitrary array-like
  361. input arguments. Examples of commonly unsupported inputs include
  362. matrices containing inf/nan, sparse matrix representations, and
  363. matrices with complicated elements.
  364. Parameters
  365. ----------
  366. a : array_like
  367. The array-like input.
  368. check_finite : bool, optional
  369. Whether to check that the input matrices contain only finite numbers.
  370. Disabling may give a performance gain, but may result in problems
  371. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  372. Default: True
  373. sparse_ok : bool, optional
  374. True if scipy sparse matrices are allowed.
  375. objects_ok : bool, optional
  376. True if arrays with dype('O') are allowed.
  377. mask_ok : bool, optional
  378. True if masked arrays are allowed.
  379. as_inexact : bool, optional
  380. True to convert the input array to a np.inexact dtype.
  381. Returns
  382. -------
  383. ret : ndarray
  384. The converted validated array.
  385. """
  386. if not sparse_ok:
  387. if issparse(a):
  388. msg = ('Sparse arrays/matrices are not supported by this function. '
  389. 'Perhaps one of the `scipy.sparse.linalg` functions '
  390. 'would work instead.')
  391. raise ValueError(msg)
  392. if not mask_ok:
  393. if np.ma.isMaskedArray(a):
  394. raise ValueError('masked arrays are not supported')
  395. toarray = np.asarray_chkfinite if check_finite else np.asarray
  396. a = toarray(a)
  397. if not objects_ok:
  398. if a.dtype is np.dtype('O'):
  399. raise ValueError('object arrays are not supported')
  400. if as_inexact:
  401. if not np.issubdtype(a.dtype, np.inexact):
  402. a = toarray(a, dtype=np.float64)
  403. return a
  404. def _validate_int(k, name, minimum=None):
  405. """
  406. Validate a scalar integer.
  407. This function can be used to validate an argument to a function
  408. that expects the value to be an integer. It uses `operator.index`
  409. to validate the value (so, for example, k=2.0 results in a
  410. TypeError).
  411. Parameters
  412. ----------
  413. k : int
  414. The value to be validated.
  415. name : str
  416. The name of the parameter.
  417. minimum : int, optional
  418. An optional lower bound.
  419. """
  420. try:
  421. k = operator.index(k)
  422. except TypeError:
  423. raise TypeError(f'{name} must be an integer.') from None
  424. if minimum is not None and k < minimum:
  425. raise ValueError(f'{name} must be an integer not less '
  426. f'than {minimum}') from None
  427. return k
  428. # Add a replacement for inspect.getfullargspec()/
  429. # The version below is borrowed from Django,
  430. # https://github.com/django/django/pull/4846.
  431. # Note an inconsistency between inspect.getfullargspec(func) and
  432. # inspect.signature(func). If `func` is a bound method, the latter does *not*
  433. # list `self` as a first argument, while the former *does*.
  434. # Hence, cook up a common ground replacement: `getfullargspec_no_self` which
  435. # mimics `inspect.getfullargspec` but does not list `self`.
  436. #
  437. # This way, the caller code does not need to know whether it uses a legacy
  438. # .getfullargspec or a bright and shiny .signature.
  439. FullArgSpec = namedtuple('FullArgSpec',
  440. ['args', 'varargs', 'varkw', 'defaults',
  441. 'kwonlyargs', 'kwonlydefaults', 'annotations'])
  442. def getfullargspec_no_self(func):
  443. """inspect.getfullargspec replacement using inspect.signature.
  444. If func is a bound method, do not list the 'self' parameter.
  445. Parameters
  446. ----------
  447. func : callable
  448. A callable to inspect
  449. Returns
  450. -------
  451. fullargspec : FullArgSpec(args, varargs, varkw, defaults, kwonlyargs,
  452. kwonlydefaults, annotations)
  453. NOTE: if the first argument of `func` is self, it is *not*, I repeat
  454. *not*, included in fullargspec.args.
  455. This is done for consistency between inspect.getargspec() under
  456. Python 2.x, and inspect.signature() under Python 3.x.
  457. """
  458. sig = wrapped_inspect_signature(func)
  459. args = [
  460. p.name for p in sig.parameters.values()
  461. if p.kind in [inspect.Parameter.POSITIONAL_OR_KEYWORD,
  462. inspect.Parameter.POSITIONAL_ONLY]
  463. ]
  464. varargs = [
  465. p.name for p in sig.parameters.values()
  466. if p.kind == inspect.Parameter.VAR_POSITIONAL
  467. ]
  468. varargs = varargs[0] if varargs else None
  469. varkw = [
  470. p.name for p in sig.parameters.values()
  471. if p.kind == inspect.Parameter.VAR_KEYWORD
  472. ]
  473. varkw = varkw[0] if varkw else None
  474. defaults = tuple(
  475. p.default for p in sig.parameters.values()
  476. if (p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD and
  477. p.default is not p.empty)
  478. ) or None
  479. kwonlyargs = [
  480. p.name for p in sig.parameters.values()
  481. if p.kind == inspect.Parameter.KEYWORD_ONLY
  482. ]
  483. kwdefaults = {p.name: p.default for p in sig.parameters.values()
  484. if p.kind == inspect.Parameter.KEYWORD_ONLY and
  485. p.default is not p.empty}
  486. annotations = {p.name: p.annotation for p in sig.parameters.values()
  487. if p.annotation is not p.empty}
  488. return FullArgSpec(args, varargs, varkw, defaults, kwonlyargs,
  489. kwdefaults or None, annotations)
  490. class _FunctionWrapper:
  491. """
  492. Object to wrap user's function, allowing picklability
  493. """
  494. def __init__(self, f, args):
  495. self.f = f
  496. self.args = [] if args is None else args
  497. def __call__(self, x):
  498. return self.f(x, *self.args)
  499. class _ScalarFunctionWrapper:
  500. """
  501. Object to wrap scalar user function, allowing picklability
  502. """
  503. def __init__(self, f, args=None):
  504. self.f = f
  505. self.args = [] if args is None else args
  506. self.nfev = 0
  507. def __call__(self, x):
  508. # Send a copy because the user may overwrite it.
  509. # The user of this class might want `x` to remain unchanged.
  510. fx = self.f(np.copy(x), *self.args)
  511. self.nfev += 1
  512. # Make sure the function returns a true scalar
  513. if not np.isscalar(fx):
  514. try:
  515. fx = np.asarray(fx).item()
  516. except (TypeError, ValueError) as e:
  517. raise ValueError(
  518. "The user-provided objective function "
  519. "must return a scalar value."
  520. ) from e
  521. return fx
  522. class MapWrapper:
  523. """
  524. Parallelisation wrapper for working with map-like callables, such as
  525. `multiprocessing.Pool.map`.
  526. Parameters
  527. ----------
  528. pool : int or map-like callable
  529. If `pool` is an integer, then it specifies the number of threads to
  530. use for parallelization. If ``int(pool) == 1``, then no parallel
  531. processing is used and the map builtin is used.
  532. If ``pool == -1``, then the pool will utilize all available CPUs.
  533. If `pool` is a map-like callable that follows the same
  534. calling sequence as the built-in map function, then this callable is
  535. used for parallelization.
  536. """
  537. def __init__(self, pool=1):
  538. self.pool = None
  539. self._mapfunc = map
  540. self._own_pool = False
  541. if callable(pool):
  542. self.pool = pool
  543. self._mapfunc = self.pool
  544. else:
  545. from multiprocessing import get_context, get_start_method
  546. method = get_start_method(allow_none=True)
  547. if method is None and os.name=='posix' and sys.version_info < (3, 14):
  548. # Python 3.13 and older used "fork" on posix, which can lead to
  549. # deadlocks. This backports that fix to older Python versions.
  550. method = 'forkserver'
  551. # user supplies a number
  552. if int(pool) == -1:
  553. # use as many processors as possible
  554. self.pool = get_context(method=method).Pool()
  555. self._mapfunc = self.pool.map
  556. self._own_pool = True
  557. elif int(pool) == 1:
  558. pass
  559. elif int(pool) > 1:
  560. # use the number of processors requested
  561. self.pool = get_context(method=method).Pool(processes=int(pool))
  562. self._mapfunc = self.pool.map
  563. self._own_pool = True
  564. else:
  565. raise RuntimeError("Number of workers specified must be -1,"
  566. " an int >= 1, or an object with a 'map' "
  567. "method")
  568. def __enter__(self):
  569. return self
  570. def terminate(self):
  571. if self._own_pool:
  572. self.pool.terminate()
  573. def join(self):
  574. if self._own_pool:
  575. self.pool.join()
  576. def close(self):
  577. if self._own_pool:
  578. self.pool.close()
  579. def __exit__(self, exc_type, exc_value, traceback):
  580. if self._own_pool:
  581. self.pool.close()
  582. self.pool.terminate()
  583. def __call__(self, func, iterable):
  584. # only accept one iterable because that's all Pool.map accepts
  585. try:
  586. return self._mapfunc(func, iterable)
  587. except TypeError as e:
  588. # wrong number of arguments
  589. raise TypeError("The map-like callable must be of the"
  590. " form f(func, iterable)") from e
  591. def _workers_wrapper(func):
  592. """
  593. Wrapper to deal with setup-cleanup of workers outside a user function via a
  594. ContextManager. It saves having to do the setup/tear down with within that
  595. function, which can be messy.
  596. """
  597. @functools.wraps(func)
  598. def inner(*args, **kwds):
  599. kwargs = kwds.copy()
  600. if 'workers' not in kwargs:
  601. _workers = map
  602. elif 'workers' in kwargs and kwargs['workers'] is None:
  603. _workers = map
  604. else:
  605. _workers = kwargs['workers']
  606. with MapWrapper(_workers) as mf:
  607. kwargs['workers'] = mf
  608. return func(*args, **kwargs)
  609. return inner
  610. def rng_integers(gen, low, high=None, size=None, dtype='int64',
  611. endpoint=False):
  612. """
  613. Return random integers from low (inclusive) to high (exclusive), or if
  614. endpoint=True, low (inclusive) to high (inclusive). Replaces
  615. `RandomState.randint` (with endpoint=False) and
  616. `RandomState.random_integers` (with endpoint=True).
  617. Return random integers from the "discrete uniform" distribution of the
  618. specified dtype. If high is None (the default), then results are from
  619. 0 to low.
  620. Parameters
  621. ----------
  622. gen : {None, np.random.RandomState, np.random.Generator}
  623. Random number generator. If None, then the np.random.RandomState
  624. singleton is used.
  625. low : int or array-like of ints
  626. Lowest (signed) integers to be drawn from the distribution (unless
  627. high=None, in which case this parameter is 0 and this value is used
  628. for high).
  629. high : int or array-like of ints
  630. If provided, one above the largest (signed) integer to be drawn from
  631. the distribution (see above for behavior if high=None). If array-like,
  632. must contain integer values.
  633. size : array-like of ints, optional
  634. Output shape. If the given shape is, e.g., (m, n, k), then m * n * k
  635. samples are drawn. Default is None, in which case a single value is
  636. returned.
  637. dtype : {str, dtype}, optional
  638. Desired dtype of the result. All dtypes are determined by their name,
  639. i.e., 'int64', 'int', etc, so byteorder is not available and a specific
  640. precision may have different C types depending on the platform.
  641. The default value is 'int64'.
  642. endpoint : bool, optional
  643. If True, sample from the interval [low, high] instead of the default
  644. [low, high) Defaults to False.
  645. Returns
  646. -------
  647. out: int or ndarray of ints
  648. size-shaped array of random integers from the appropriate distribution,
  649. or a single such random int if size not provided.
  650. """
  651. if isinstance(gen, np.random.Generator):
  652. return gen.integers(low, high=high, size=size, dtype=dtype,
  653. endpoint=endpoint)
  654. else:
  655. if gen is None:
  656. # default is RandomState singleton used by np.random.
  657. gen = np.random.mtrand._rand
  658. if endpoint:
  659. # inclusive of endpoint
  660. # remember that low and high can be arrays, so don't modify in
  661. # place
  662. if high is None:
  663. return gen.randint(low + 1, size=size, dtype=dtype)
  664. if high is not None:
  665. return gen.randint(low, high=high + 1, size=size, dtype=dtype)
  666. # exclusive
  667. return gen.randint(low, high=high, size=size, dtype=dtype)
  668. @contextmanager
  669. def _fixed_default_rng(seed=1638083107694713882823079058616272161):
  670. """Context with a fixed np.random.default_rng seed."""
  671. orig_fun = np.random.default_rng
  672. np.random.default_rng = lambda seed=seed: orig_fun(seed)
  673. try:
  674. yield
  675. finally:
  676. np.random.default_rng = orig_fun
  677. @contextmanager
  678. def ignore_warns(expected_warning, *, match=None):
  679. with warnings.catch_warnings():
  680. warnings.filterwarnings("ignore", match, expected_warning)
  681. yield
  682. def _rng_html_rewrite(func):
  683. """Rewrite the HTML rendering of ``np.random.default_rng``.
  684. This is intended to decorate
  685. ``numpydoc.docscrape_sphinx.SphinxDocString._str_examples``.
  686. Examples are only run by Sphinx when there are plot involved. Even so,
  687. it does not change the result values getting printed.
  688. """
  689. # hexadecimal or number seed, case-insensitive
  690. pattern = re.compile(r'np.random.default_rng\((0x[0-9A-F]+|\d+)\)', re.I)
  691. def _wrapped(*args, **kwargs):
  692. res = func(*args, **kwargs)
  693. lines = [
  694. re.sub(pattern, 'np.random.default_rng()', line)
  695. for line in res
  696. ]
  697. return lines
  698. return _wrapped
  699. def _argmin(a, keepdims=False, axis=None):
  700. """
  701. argmin with a `keepdims` parameter.
  702. See https://github.com/numpy/numpy/issues/8710
  703. If axis is not None, a.shape[axis] must be greater than 0.
  704. """
  705. res = np.argmin(a, axis=axis)
  706. if keepdims and axis is not None:
  707. res = np.expand_dims(res, axis=axis)
  708. return res
def _contains_nan(
    a: Array,
    nan_policy: Literal["propagate", "raise", "omit"] = "propagate",
    *,
    xp_omit_okay: bool = False,
    xp: ModuleType | None = None,
) -> Array | bool:
    """
    Check `a` for NaNs and enforce `nan_policy`.

    Parameters
    ----------
    a : Array
        Input array.
    nan_policy : {"propagate", "raise", "omit"}, optional
        How NaNs should be handled. ``"raise"`` raises if `a` contains NaN.
        ``"propagate"`` and ``"omit"`` only report whether NaNs are present
        (with the lazy-array caveat below for ``"omit"``).
    xp_omit_okay : bool, optional
        If True, skip the extra check applied to ``nan_policy="omit"`` for
        non-NumPy namespaces.
    xp : module, optional
        Array namespace of `a`. If None, it is inferred with
        `array_namespace`.

    Returns
    -------
    contains_nan : Array or bool
        ``False`` for empty input and for dtypes other than real/complex
        floating and NumPy ``object``. For floating dtypes, the result of
        `xp.isnan` on a full reduction (possibly an array object rather than
        a Python bool). For NumPy object arrays, a Python bool.

    Raises
    ------
    ValueError
        If `nan_policy` is not one of the accepted values, or if
        ``nan_policy="raise"`` and `a` contains NaN.
    TypeError
        If `a` is a lazy array and ``nan_policy="raise"``, or
        ``nan_policy="omit"`` with ``xp_omit_okay=False`` and a non-NumPy `xp`.
    """
    # Regarding `xp_omit_okay`: Temporarily, while `_axis_nan_policy` does not
    # handle non-NumPy arrays, most functions that call `_contains_nan` want
    # it to raise an error if `nan_policy='omit'` and `xp` is not `np`.
    # Some functions support `nan_policy='omit'` natively, so setting this to
    # `True` prevents the error from being raised.
    # NOTE(review): the code below only raises in the 'omit' / non-NumPy case
    # when `a` is lazy; eager non-NumPy arrays pass through silently. Confirm
    # whether this comment or the code reflects the intended behavior.
    policies = {"propagate", "raise", "omit"}
    if nan_policy not in policies:
        msg = f"nan_policy must be one of {policies}."
        raise ValueError(msg)
    # Empty input cannot contain NaN; note this returns before `xp` is resolved.
    if xp_size(a) == 0:
        return False
    if xp is None:
        xp = array_namespace(a)
    if xp.isdtype(a.dtype, "real floating"):
        # Faster and less memory-intensive than xp.any(xp.isnan(a)), and unlike other
        # reductions, `max`/`min` won't return NaN unless there is a NaN in the data.
        contains_nan = xp.isnan(xp.max(a))
    elif xp.isdtype(a.dtype, "complex floating"):
        # Typically `real` and `imag` produce views; otherwise, `xp.any(xp.isnan(a))`
        # would be more efficient.
        contains_nan = xp.isnan(xp.max(xp.real(a))) | xp.isnan(xp.max(xp.imag(a)))
    elif is_numpy(xp) and np.issubdtype(a.dtype, object):
        # Object arrays may mix numbers with arbitrary objects, so scan
        # element by element and stop at the first NaN.
        contains_nan = False
        for el in a.ravel():
            # isnan doesn't work on non-numeric elements
            if np.issubdtype(type(el), np.number) and np.isnan(el):
                contains_nan = True
                break
    else:
        # Only `object` and `inexact` arrays can have NaNs
        return False
    # The implicit call to bool(contains_nan) must happen after testing
    # nan_policy to prevent lazy and device-bound xps from raising in the
    # default policy='propagate' case.
    if nan_policy == 'raise':
        if is_lazy_array(a):
            msg = "nan_policy='raise' is not supported for lazy arrays."
            raise TypeError(msg)
        if contains_nan:
            msg = "The input contains nan values"
            raise ValueError(msg)
    elif nan_policy == 'omit' and not xp_omit_okay and not is_numpy(xp):
        if is_lazy_array(a):
            msg = "nan_policy='omit' is not supported for lazy arrays."
            raise TypeError(msg)
    return contains_nan
  762. def _rename_parameter(old_name, new_name, dep_version=None):
  763. """
  764. Generate decorator for backward-compatible keyword renaming.
  765. Apply the decorator generated by `_rename_parameter` to functions with a
  766. recently renamed parameter to maintain backward-compatibility.
  767. After decoration, the function behaves as follows:
  768. If only the new parameter is passed into the function, behave as usual.
  769. If only the old parameter is passed into the function (as a keyword), raise
  770. a DeprecationWarning if `dep_version` is provided, and behave as usual
  771. otherwise.
  772. If both old and new parameters are passed into the function, raise a
  773. DeprecationWarning if `dep_version` is provided, and raise the appropriate
  774. TypeError (function got multiple values for argument).
  775. Parameters
  776. ----------
  777. old_name : str
  778. Old name of parameter
  779. new_name : str
  780. New name of parameter
  781. dep_version : str, optional
  782. Version of SciPy in which old parameter was deprecated in the format
  783. 'X.Y.Z'. If supplied, the deprecation message will indicate that
  784. support for the old parameter will be removed in version 'X.Y+2.Z'
  785. Notes
  786. -----
  787. Untested with functions that accept *args. Probably won't work as written.
  788. """
  789. def decorator(fun):
  790. @functools.wraps(fun)
  791. def wrapper(*args, **kwargs):
  792. if old_name in kwargs:
  793. if dep_version:
  794. end_version = dep_version.split('.')
  795. end_version[1] = str(int(end_version[1]) + 2)
  796. end_version = '.'.join(end_version)
  797. message = (f"Use of keyword argument `{old_name}` is "
  798. f"deprecated and replaced by `{new_name}`. "
  799. f"Support for `{old_name}` will be removed "
  800. f"in SciPy {end_version}.")
  801. warnings.warn(message, DeprecationWarning, stacklevel=2)
  802. if new_name in kwargs:
  803. message = (f"{fun.__name__}() got multiple values for "
  804. f"argument now known as `{new_name}`")
  805. raise TypeError(message)
  806. kwargs[new_name] = kwargs.pop(old_name)
  807. return fun(*args, **kwargs)
  808. return wrapper
  809. return decorator
  810. def _rng_spawn(rng, n_children):
  811. # spawns independent RNGs from a parent RNG
  812. bg = rng._bit_generator
  813. ss = bg._seed_seq
  814. child_rngs = [np.random.Generator(type(bg)(child_ss))
  815. for child_ss in ss.spawn(n_children)]
  816. return child_rngs
  817. def _get_nan(*data, shape=(), xp=None):
  818. xp = array_namespace(*data) if xp is None else xp
  819. # Get NaN of appropriate dtype for data
  820. dtype = xp_result_type(*data, force_floating=True, xp=xp)
  821. device = xp_result_device(*data)
  822. res = xp.full(shape, xp.nan, dtype=dtype, device=device)
  823. if not shape:
  824. res = res[()]
  825. # whenever mdhaber/marray#89 is resolved, could just return `res`
  826. return res.data if is_marray(xp) else res
  827. def normalize_axis_index(axis, ndim):
  828. # Check if `axis` is in the correct range and normalize it
  829. if axis < -ndim or axis >= ndim:
  830. msg = f"axis {axis} is out of bounds for array of dimension {ndim}"
  831. raise AxisError(msg)
  832. if axis < 0:
  833. axis = axis + ndim
  834. return axis
  835. def _call_callback_maybe_halt(callback, res):
  836. """Call wrapped callback; return True if algorithm should stop.
  837. Parameters
  838. ----------
  839. callback : callable or None
  840. A user-provided callback wrapped with `_wrap_callback`
  841. res : OptimizeResult
  842. Information about the current iterate
  843. Returns
  844. -------
  845. halt : bool
  846. True if minimization should stop
  847. """
  848. if callback is None:
  849. return False
  850. try:
  851. callback(res)
  852. return False
  853. except StopIteration:
  854. callback.stop_iteration = True
  855. return True
  856. class _RichResult(dict):
  857. """ Container for multiple outputs with pretty-printing """
  858. def __getattr__(self, name):
  859. try:
  860. return self[name]
  861. except KeyError as e:
  862. raise AttributeError(name) from e
  863. __setattr__ = dict.__setitem__ # type: ignore[assignment]
  864. __delattr__ = dict.__delitem__ # type: ignore[assignment]
  865. def __repr__(self):
  866. order_keys = ['message', 'success', 'status', 'fun', 'funl', 'x', 'xl',
  867. 'col_ind', 'nit', 'lower', 'upper', 'eqlin', 'ineqlin',
  868. 'converged', 'flag', 'function_calls', 'iterations',
  869. 'root']
  870. order_keys = getattr(self, '_order_keys', order_keys)
  871. # 'slack', 'con' are redundant with residuals
  872. # 'crossover_nit' is probably not interesting to most users
  873. omit_keys = {'slack', 'con', 'crossover_nit', '_order_keys'}
  874. def key(item):
  875. try:
  876. return order_keys.index(item[0].lower())
  877. except ValueError: # item not in list
  878. return np.inf
  879. def omit_redundant(items):
  880. for item in items:
  881. if item[0] in omit_keys:
  882. continue
  883. yield item
  884. def item_sorter(d):
  885. return sorted(omit_redundant(d.items()), key=key)
  886. if self.keys():
  887. return _dict_formatter(self, sorter=item_sorter)
  888. else:
  889. return self.__class__.__name__ + "()"
  890. def __dir__(self):
  891. return list(self.keys())
  892. def _indenter(s, n=0):
  893. """
  894. Ensures that lines after the first are indented by the specified amount
  895. """
  896. split = s.split("\n")
  897. indent = " "*n
  898. return ("\n" + indent).join(split)
  899. def _float_formatter_10(x):
  900. """
  901. Returns a string representation of a float with exactly ten characters
  902. """
  903. if np.isposinf(x):
  904. return " inf"
  905. elif np.isneginf(x):
  906. return " -inf"
  907. elif np.isnan(x):
  908. return " nan"
  909. return np.format_float_scientific(x, precision=3, pad_left=2, unique=False)
  910. def _dict_formatter(d, n=0, mplus=1, sorter=None):
  911. """
  912. Pretty printer for dictionaries
  913. `n` keeps track of the starting indentation;
  914. lines are indented by this much after a line break.
  915. `mplus` is additional left padding applied to keys
  916. """
  917. if isinstance(d, dict):
  918. m = max(map(len, list(d.keys()))) + mplus # width to print keys
  919. s = '\n'.join([k.rjust(m) + ': ' + # right justified, width m
  920. _indenter(_dict_formatter(v, m+n+2, 0, sorter), m+2)
  921. for k, v in sorter(d)]) # +2 for ': '
  922. else:
  923. # By default, NumPy arrays print with linewidth=76. `n` is
  924. # the indent at which a line begins printing, so it is subtracted
  925. # from the default to avoid exceeding 76 characters total.
  926. # `edgeitems` is the number of elements to include before and after
  927. # ellipses when arrays are not shown in full.
  928. # `threshold` is the maximum number of elements for which an
  929. # array is shown in full.
  930. # These values tend to work well for use with OptimizeResult.
  931. with np.printoptions(linewidth=76-n, edgeitems=2, threshold=12,
  932. formatter={'float_kind': _float_formatter_10}):
  933. s = str(d)
  934. return s
# Note appended (stripped of trailing whitespace) to the "Extended Summary"
# section of functions decorated with `_apply_over_batch`, via `FunctionDoc`.
# This is runtime text, so edits here change the rendered docstrings.
_batch_note = """
The documentation is written assuming array arguments are of specified
"core" shapes. However, array argument(s) of this function may have additional
"batch" dimensions prepended to the core shape. In this case, the array is treated
as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.
Note that calls with zero-size batches are unsupported and will raise a ``ValueError``.
"""
  942. def _apply_over_batch(*argdefs):
  943. """
  944. Factory for decorator that applies a function over batched arguments.
  945. Array arguments may have any number of core dimensions (typically 0,
  946. 1, or 2) and any broadcastable batch shapes. There may be any
  947. number of array outputs of any number of dimensions. Assumptions
  948. right now - which are satisfied by all functions of interest in `linalg` -
  949. are that all array inputs are consecutive keyword or positional arguments,
  950. and that the wrapped function returns either a single array or a tuple of
  951. arrays. It's only as general as it needs to be right now - it can be extended.
  952. Parameters
  953. ----------
  954. *argdefs : tuple of (str, int)
  955. Definitions of array arguments: the keyword name of the argument, and
  956. the number of core dimensions.
  957. Example:
  958. --------
  959. `linalg.eig` accepts two matrices as the first two arguments `a` and `b`, where
  960. `b` is optional, and returns one array or a tuple of arrays, depending on the
  961. values of other positional or keyword arguments. To generate a wrapper that applies
  962. the function over batches of `a` and optionally `b` :
  963. >>> _apply_over_batch(('a', 2), ('b', 2))
  964. """
  965. names, ndims = list(zip(*argdefs))
  966. n_arrays = len(names)
  967. def decorator(f):
  968. @functools.wraps(f)
  969. def wrapper(*args, **kwargs):
  970. args = list(args)
  971. # Ensure all arrays in `arrays`, other arguments in `other_args`/`kwargs`
  972. arrays, other_args = args[:n_arrays], args[n_arrays:]
  973. for i, name in enumerate(names):
  974. if name in kwargs:
  975. if i + 1 <= len(args):
  976. raise ValueError(f'{f.__name__}() got multiple values '
  977. f'for argument `{name}`.')
  978. else:
  979. arrays.append(kwargs.pop(name))
  980. xp = array_namespace(*arrays)
  981. # Determine core and batch shapes
  982. batch_shapes = []
  983. core_shapes = []
  984. for i, (array, ndim) in enumerate(zip(arrays, ndims)):
  985. array = None if array is None else xp.asarray(array)
  986. shape = () if array is None else array.shape
  987. if ndim == "1|2": # special case for `solve`, etc.
  988. ndim = 2 if array.ndim >= 2 else 1
  989. arrays[i] = array
  990. batch_shapes.append(shape[:-ndim] if ndim > 0 else shape)
  991. core_shapes.append(shape[-ndim:] if ndim > 0 else ())
  992. # Early exit if call is not batched
  993. if not any(batch_shapes):
  994. return f(*arrays, *other_args, **kwargs)
  995. # Determine broadcasted batch shape
  996. batch_shape = np.broadcast_shapes(*batch_shapes) # Gives OK error message
  997. # We can't support zero-size batches right now because without data with
  998. # which to call the function, the decorator doesn't even know the *number*
  999. # of outputs, let alone their core shapes or dtypes.
  1000. if math.prod(batch_shape) == 0:
  1001. message = f'`{f.__name__}` does not support zero-size batches.'
  1002. raise ValueError(message)
  1003. # Broadcast arrays to appropriate shape
  1004. for i, (array, core_shape) in enumerate(zip(arrays, core_shapes)):
  1005. if array is None:
  1006. continue
  1007. arrays[i] = xp.broadcast_to(array, batch_shape + core_shape)
  1008. # Main loop
  1009. results = []
  1010. for index in np.ndindex(batch_shape):
  1011. result = f(*((array[index] if array is not None else None)
  1012. for array in arrays), *other_args, **kwargs)
  1013. # Assume `result` is either a tuple or single array. This is easily
  1014. # generalized by allowing the contributor to pass an `unpack_result`
  1015. # callable to the decorator factory.
  1016. result = (result,) if not isinstance(result, tuple) else result
  1017. results.append(result)
  1018. results = list(zip(*results))
  1019. # Reshape results
  1020. for i, result in enumerate(results):
  1021. result = xp.stack(result)
  1022. core_shape = result.shape[1:]
  1023. results[i] = xp.reshape(result, batch_shape + core_shape)
  1024. # Assume `result` should be a single array if there is only one element or
  1025. # a `tuple` otherwise. This is easily generalized by allowing the
  1026. # contributor to pass an `pack_result` callable to the decorator factory.
  1027. return results[0] if len(results) == 1 else results
  1028. doc = FunctionDoc(wrapper)
  1029. doc['Extended Summary'].append(_batch_note.rstrip())
  1030. wrapper.__doc__ = str(doc).split("\n", 1)[1].lstrip(" \n") # remove signature
  1031. return wrapper
  1032. return decorator
  1033. def np_vecdot(x1, x2, /, *, axis=-1):
  1034. # `np.vecdot` has advantages (e.g. see gh-22462), so let's use it when
  1035. # available. As functions are translated to Array API, `np_vecdot` can be
  1036. # replaced with `xp.vecdot`.
  1037. if np.__version__ > "2.0":
  1038. return np.vecdot(x1, x2, axis=axis)
  1039. else:
  1040. # of course there are other fancy ways of doing this (e.g. `einsum`)
  1041. # but let's keep it simple since it's temporary
  1042. return np.sum(x1 * x2, axis=axis)
  1043. def _dedent_for_py313(s):
  1044. """Apply textwrap.dedent to s for Python versions 3.13 or later."""
  1045. return s if sys.version_info < (3, 13) else textwrap.dedent(s)
  1046. def broadcastable(shape_a: tuple[int, ...], shape_b: tuple[int, ...]) -> bool:
  1047. """Check if two shapes are broadcastable."""
  1048. return all(
  1049. (m == n) or (m == 1) or (n == 1) for m, n in zip(shape_a[::-1], shape_b[::-1])
  1050. )