helper.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. from numbers import Number
  2. import operator
  3. import os
  4. import threading
  5. import contextlib
  6. import numpy as np
  7. from scipy._lib._util import copy_if_needed
  8. from scipy._lib._array_api import xp_capabilities
  9. # good_size is exposed (and used) from this import
  10. from .pypocketfft import good_size, prev_good_size
  11. __all__ = ['good_size', 'prev_good_size', 'set_workers', 'get_workers']
  12. _config = threading.local()
  13. _cpu_count = os.cpu_count()
  14. def _iterable_of_int(x, name=None):
  15. """Convert ``x`` to an iterable sequence of int
  16. Parameters
  17. ----------
  18. x : value, or sequence of values, convertible to int
  19. name : str, optional
  20. Name of the argument being converted, only used in the error message
  21. Returns
  22. -------
  23. y : ``List[int]``
  24. """
  25. if isinstance(x, Number):
  26. x = (x,)
  27. try:
  28. x = [operator.index(a) for a in x]
  29. except TypeError as e:
  30. name = name or "value"
  31. raise ValueError(f"{name} must be a scalar or iterable of integers") from e
  32. return x
  33. def _init_nd_shape_and_axes(x, shape, axes):
  34. """
  35. Handle shape and axes arguments for N-D transforms.
  36. Returns the shape and axes in a standard form, taking into account negative
  37. values and checking for various potential errors.
  38. Parameters
  39. ----------
  40. x : ndarray
  41. The input array.
  42. shape : int or array_like of ints or None
  43. The shape of the result. If both `shape` and `axes` (see below) are
  44. None, `shape` is ``x.shape``; if `shape` is None but `axes` is
  45. not None, then `shape` is ``numpy.take(x.shape, axes, axis=0)``.
  46. If `shape` is -1, the size of the corresponding dimension of `x` is
  47. used.
  48. axes : int or array_like of ints or None
  49. Axes along which the calculation is computed.
  50. The default is over all axes.
  51. Negative indices are automatically converted to their positive
  52. counterparts.
  53. Returns
  54. -------
  55. shape : tuple
  56. The shape of the result as a tuple of integers.
  57. axes : list
  58. Axes along which the calculation is computed, as a list of integers.
  59. """
  60. noshape = shape is None
  61. noaxes = axes is None
  62. if not noaxes:
  63. axes = _iterable_of_int(axes, 'axes')
  64. axes = [a + x.ndim if a < 0 else a for a in axes]
  65. if any(a >= x.ndim or a < 0 for a in axes):
  66. raise ValueError("axes exceeds dimensionality of input")
  67. if len(set(axes)) != len(axes):
  68. raise ValueError("all axes must be unique")
  69. if not noshape:
  70. shape = _iterable_of_int(shape, 'shape')
  71. if axes and len(axes) != len(shape):
  72. raise ValueError("when given, axes and shape arguments"
  73. " have to be of the same length")
  74. if noaxes:
  75. if len(shape) > x.ndim:
  76. raise ValueError("shape requires more axes than are present")
  77. axes = range(x.ndim - len(shape), x.ndim)
  78. shape = [x.shape[a] if s == -1 else s for s, a in zip(shape, axes)]
  79. elif noaxes:
  80. shape = list(x.shape)
  81. axes = range(x.ndim)
  82. else:
  83. shape = [x.shape[a] for a in axes]
  84. if any(s < 1 for s in shape):
  85. raise ValueError(
  86. f"invalid number of data points ({shape}) specified")
  87. return tuple(shape), list(axes)
  88. def _asfarray(x):
  89. """
  90. Convert to array with floating or complex dtype.
  91. float16 values are also promoted to float32.
  92. """
  93. if not hasattr(x, "dtype"):
  94. x = np.asarray(x)
  95. if x.dtype == np.float16:
  96. return np.asarray(x, np.float32)
  97. elif x.dtype.kind not in 'fc':
  98. return np.asarray(x, np.float64)
  99. # Require native byte order
  100. dtype = x.dtype.newbyteorder('=')
  101. # Always align input
  102. copy = True if not x.flags['ALIGNED'] else copy_if_needed
  103. return np.array(x, dtype=dtype, copy=copy)
  104. def _datacopied(arr, original):
  105. """
  106. Strict check for `arr` not sharing any data with `original`,
  107. under the assumption that arr = asarray(original)
  108. """
  109. if arr is original:
  110. return False
  111. if not isinstance(original, np.ndarray) and hasattr(original, '__array__'):
  112. return False
  113. return arr.base is None
  114. def _fix_shape(x, shape, axes):
  115. """Internal auxiliary function for _raw_fft, _raw_fftnd."""
  116. must_copy = False
  117. # Build an nd slice with the dimensions to be read from x
  118. index = [slice(None)]*x.ndim
  119. for n, ax in zip(shape, axes):
  120. if x.shape[ax] >= n:
  121. index[ax] = slice(0, n)
  122. else:
  123. index[ax] = slice(0, x.shape[ax])
  124. must_copy = True
  125. index = tuple(index)
  126. if not must_copy:
  127. return x[index], False
  128. s = list(x.shape)
  129. for n, axis in zip(shape, axes):
  130. s[axis] = n
  131. z = np.zeros(s, x.dtype)
  132. z[index] = x[index]
  133. return z, True
  134. def _fix_shape_1d(x, n, axis):
  135. if n < 1:
  136. raise ValueError(
  137. f"invalid number of data points ({n}) specified")
  138. return _fix_shape(x, (n,), (axis,))
  139. _NORM_MAP = {None: 0, 'backward': 0, 'ortho': 1, 'forward': 2}
  140. def _normalization(norm, forward):
  141. """Returns the pypocketfft normalization mode from the norm argument"""
  142. try:
  143. inorm = _NORM_MAP[norm]
  144. return inorm if forward else (2 - inorm)
  145. except KeyError:
  146. raise ValueError(
  147. f'Invalid norm value {norm!r}, should '
  148. 'be "backward", "ortho" or "forward"') from None
  149. def _workers(workers):
  150. if workers is None:
  151. return getattr(_config, 'default_workers', 1)
  152. if workers < 0:
  153. if workers >= -_cpu_count:
  154. workers += 1 + _cpu_count
  155. else:
  156. raise ValueError(f"workers value out of range; got {workers}, must not be"
  157. f" less than {-_cpu_count}")
  158. elif workers == 0:
  159. raise ValueError("workers must not be zero")
  160. return workers
  161. @xp_capabilities(out_of_scope=True)
  162. @contextlib.contextmanager
  163. def set_workers(workers):
  164. """Context manager for the default number of workers used in `scipy.fft`
  165. Parameters
  166. ----------
  167. workers : int
  168. The default number of workers to use
  169. Examples
  170. --------
  171. >>> import numpy as np
  172. >>> from scipy import fft, signal
  173. >>> rng = np.random.default_rng()
  174. >>> x = rng.standard_normal((128, 64))
  175. >>> with fft.set_workers(4):
  176. ... y = signal.fftconvolve(x, x)
  177. """
  178. old_workers = get_workers()
  179. _config.default_workers = _workers(operator.index(workers))
  180. try:
  181. yield
  182. finally:
  183. _config.default_workers = old_workers
  184. @xp_capabilities(out_of_scope=True)
  185. def get_workers():
  186. """Returns the default number of workers within the current context
  187. Examples
  188. --------
  189. >>> from scipy import fft
  190. >>> fft.get_workers()
  191. 1
  192. >>> with fft.set_workers(4):
  193. ... fft.get_workers()
  194. 4
  195. """
  196. return getattr(_config, 'default_workers', 1)