_delegators.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574
  """Delegators for alternative backends in scipy.signal.

  For every public function ``signal.func`` this module defines a
  ``func_signature`` whose signature must match that of ``signal.func``.
  The job of a ``func_signature`` is to know which arguments of
  ``signal.func`` are arrays. It is called with the same arguments as
  ``signal.func``. It returns the array namespace derived from the array
  arguments (via ``array_namespace``). For functions that take no array
  inputs, it returns a fixed namespace or the caller-supplied ``xp``.

  * The signatures were generated with::

      import inspect
      from scipy import signal

      names = [x for x in dir(signal) if not x.startswith('_')]
      objs = [getattr(signal, name) for name in names]
      funcs = [obj for obj in objs if inspect.isroutine(obj)]

      for func in funcs:
          try:
              sig = inspect.signature(func)
          except ValueError:
              sig = "( FIXME )"
          print(f"def {func.__name__}_signature{sig}:")
          print("    return array_namespace(...)")
          print()

  * The choice of which arguments to delegate on was made by manually
    trawling the documentation for array-like and array arguments.
  """
  23. import numpy as np
  24. from scipy._lib._array_api import array_namespace, np_compat, is_jax
  def _skip_if_lti(arg):
      """Return the array-like parts of a ``system`` argument.

      ``system`` may be a tuple of array-likes or an LTI object. A tuple is
      returned unchanged so that its elements take part in namespace
      detection. Anything else becomes ``(None,)``, so callers can always
      star-unpack the result. Revisit once a cupyx LTI class is supported.
      """
      return arg if isinstance(arg, tuple) else (None,)
  def _skip_if_str_or_tuple(window):
      """Return ``window`` only when it may be an array-like.

      A ``window`` argument can be a window name (``str``), a name with
      parameters (``tuple``), a callable, or an array-like of samples.
      The first three forms are mapped to ``None`` so that they do not take
      part in namespace detection. Any other value (including ``None``) is
      returned unchanged.
      """
      if isinstance(window, (str, tuple)) or callable(window):
          return None
      return window
  def _skip_if_poly1d(arg):
      """Hide ``np.poly1d`` instances from namespace detection.

      Returns ``None`` for a ``np.poly1d`` and ``arg`` unchanged otherwise.
      """
      if isinstance(arg, np.poly1d):
          return None
      return arg
  43. ###################
  44. def abcd_normalize_signature(A=None, B=None, C=None, D=None):
  45. return array_namespace(A, B, C, D)
  46. def argrelextrema_signature(data, *args, **kwds):
  47. return array_namespace(data)
  48. argrelmax_signature = argrelextrema_signature
  49. argrelmin_signature = argrelextrema_signature
  50. def band_stop_obj_signature(wp, ind, passb, stopb, gpass, gstop, type):
  51. return array_namespace(passb, stopb)
  52. def bessel_signature(N, Wn, *args, **kwds):
  53. return array_namespace(Wn)
  54. butter_signature = bessel_signature
  55. def cheby2_signature(N, rs, Wn, *args, **kwds):
  56. return array_namespace(Wn)
  57. def cheby1_signature(N, rp, Wn, *args, **kwds):
  58. return array_namespace(Wn)
  59. def ellip_signature(N, rp, rs, Wn, *args, **kwds):
  60. return array_namespace(Wn)
  61. ########################## XXX: no arrays in, arrays out
  62. def besselap_signature(N, norm='phase', *, xp=None, device=None):
  63. return np if xp is None else xp
  64. def buttap_signature(N, *, xp=None, device=None):
  65. return np if xp is None else xp
  66. def cheb1ap_signature(N, rp, *, xp=None, device=None):
  67. return np if xp is None else xp
  68. def cheb2ap_signature(N, rs, *, xp=None, device=None):
  69. return np if xp is None else xp
  70. def ellipap_signature(N, rp, rs, *, xp=None, device=None):
  71. return np if xp is None else xp
  72. def correlation_lags_signature(in1_len, in2_len, mode='full'):
  73. return np
  74. def czt_points_signature(m, w=None, a=(1+0j)):
  75. return np
  76. def gammatone_signature(
  77. freq, ftype, order=None, numtaps=None, fs=None, *, xp=None, device=None
  78. ):
  79. return np_compat if xp is None else xp
  80. def iircomb_signature(
  81. w0, Q, ftype='notch', fs=2.0, *, pass_zero=False, xp=None, device=None
  82. ):
  83. return np_compat if xp is None else xp
  84. def iirnotch_signature(w0, Q, fs=2.0, *, xp=None, device=None):
  85. return np if xp is None else xp
  86. iirpeak_signature = iirnotch_signature
  87. def savgol_coeffs_signature(
  88. window_length, polyorder, deriv=0, delta=1.0, pos=None, use='conv',
  89. *, xp=None, device=None
  90. ):
  91. return np if xp is None else xp
  92. def unit_impulse_signature(shape, idx=None, dtype=float):
  93. return np
  94. ############################
  95. def buttord_signature(wp, ws, gpass, gstop, analog=False, fs=None):
  96. return array_namespace(wp, ws)
  97. cheb1ord_signature = buttord_signature
  98. cheb2ord_signature = buttord_signature
  99. ellipord_signature = buttord_signature
  100. ########### NB: scalars in, scalars out
  101. def kaiser_atten_signature(numtaps, width):
  102. return np
  103. def kaiser_beta_signature(a):
  104. return np
  105. def kaiserord_signature(ripple, width):
  106. return np
  107. def get_window_signature(window, Nx, fftbins=True, *, xp=None, device=None):
  108. return np if xp is None else xp
  109. #################################
  110. def bode_signature(system, w=None, n=100):
  111. return array_namespace(*_skip_if_lti(system), w)
  112. dbode_signature = bode_signature
  113. def freqresp_signature(system, w=None, n=10000):
  114. return array_namespace(*_skip_if_lti(system), w)
  115. dfreqresp_signature = freqresp_signature
  116. def impulse_signature(system, X0=None, T=None, N=None):
  117. return array_namespace(*_skip_if_lti(system), X0, T)
  118. def dimpulse_signature(system, x0=None, t=None, n=None):
  119. return array_namespace(*_skip_if_lti(system), x0, t)
  120. def lsim_signature(system, U, T, X0=None, interp=True):
  121. return array_namespace(*_skip_if_lti(system), U, T, X0)
  122. def dlsim_signature(system, u, t=None, x0=None):
  123. return array_namespace(*_skip_if_lti(system), u, t, x0)
  124. def step_signature(system, X0=None, T=None, N=None):
  125. return array_namespace(*_skip_if_lti(system), X0, T)
  126. def dstep_signature(system, x0=None, t=None, n=None):
  127. return array_namespace(*_skip_if_lti(system), x0, t)
  128. def cont2discrete_signature(system, dt, method='zoh', alpha=None):
  129. return array_namespace(*_skip_if_lti(system))
  130. def bilinear_signature(b, a, fs=1.0):
  131. return array_namespace(b, a)
  132. def bilinear_zpk_signature(z, p, k, fs):
  133. return array_namespace(z, p)
  134. def chirp_signature(t,*args, **kwds):
  135. return array_namespace(t)
  136. ############## XXX: array-likes in, str out
  137. def choose_conv_method_signature(in1, in2, *args, **kwds):
  138. return array_namespace(in1, in2)
  139. ############################################
  140. def convolve_signature(in1, in2, *args, **kwds):
  141. return array_namespace(in1, in2)
  142. fftconvolve_signature = convolve_signature
  143. oaconvolve_signature = convolve_signature
  144. correlate_signature = convolve_signature
  145. correlate_signature = convolve_signature
  146. convolve2d_signature = convolve_signature
  147. correlate2d_signature = convolve_signature
  # Spectral estimation, resampling and COLA/NOLA checks. Besides the signal
  # arrays, ``window`` takes part only when it is not a name, a tuple or a
  # callable (see ``_skip_if_str_or_tuple``).
  def coherence_signature(
      x, y, fs=1.0, window='hann_periodic', *args, **kwds
  ):
      win = _skip_if_str_or_tuple(window)
      return array_namespace(x, y, win)


  def csd_signature(x, y, fs=1.0, window='hann_periodic', *args, **kwds):
      win = _skip_if_str_or_tuple(window)
      return array_namespace(x, y, win)


  def periodogram_signature(x, fs=1.0, window='boxcar', *args, **kwds):
      win = _skip_if_str_or_tuple(window)
      return array_namespace(x, win)


  def welch_signature(x, fs=1.0, window='hann_periodic', *args, **kwds):
      win = _skip_if_str_or_tuple(window)
      return array_namespace(x, win)


  def spectrogram_signature(
      x, fs=1.0, window=('tukey_periodic', 0.25), *args, **kwds
  ):
      win = _skip_if_str_or_tuple(window)
      return array_namespace(x, win)


  def stft_signature(x, fs=1.0, window='hann_periodic', *args, **kwds):
      win = _skip_if_str_or_tuple(window)
      return array_namespace(x, win)


  def istft_signature(Zxx, fs=1.0, window='hann_periodic', *args, **kwds):
      win = _skip_if_str_or_tuple(window)
      return array_namespace(Zxx, win)


  def resample_signature(
      x, num, t=None, axis=0, window=None, domain='time'
  ):
      win = _skip_if_str_or_tuple(window)
      return array_namespace(x, t, win)


  def resample_poly_signature(
      x, up, down, axis=0, window=('kaiser', 5.0), *args, **kwds
  ):
      win = _skip_if_str_or_tuple(window)
      return array_namespace(x, win)


  def check_COLA_signature(window, nperseg, noverlap, tol=1e-10):
      win = _skip_if_str_or_tuple(window)
      return array_namespace(win)


  def check_NOLA_signature(window, nperseg, noverlap, tol=1e-10):
      win = _skip_if_str_or_tuple(window)
      return array_namespace(win)
  170. def czt_signature(x, *args, **kwds):
  171. return array_namespace(x)
  172. decimate_signature = czt_signature
  173. gauss_spline_signature = czt_signature
  174. def deconvolve_signature(signal, divisor):
  175. return array_namespace(signal, divisor)
  176. def detrend_signature(data, axis=1, type='linear', bp=0, *args, **kwds):
  177. xp = array_namespace(data)
  178. # JAX doesn't accept JAX arrays for bp, only ints, lists and NumPy
  179. # arrays.
  180. return xp if is_jax(xp) else array_namespace(data, bp)
  181. def filtfilt_signature(b, a, x, *args, **kwds):
  182. return array_namespace(b, a, x)
  183. def lfilter_signature(b, a, x, axis=-1, zi=None):
  184. return array_namespace(b, a, x, zi)
  185. def envelope_signature(z, *args, **kwds):
  186. return array_namespace(z)
  187. def find_peaks_signature(
  188. x, height=None, threshold=None, distance=None, prominence=None, width=None,
  189. wlen=None, rel_height=0.5, plateau_size=None
  190. ):
  191. # TODO: fix me - `prominence` is not necessarily an array.
  192. # return array_namespace(x, height, threshold, prominence, width, plateau_size)
  193. # See https://github.com/scipy/scipy/pull/22644#issuecomment-3568443768. For now:
  194. return np_compat
  195. def find_peaks_cwt_signature(
  196. vector, widths, wavelet=None, max_distances=None, *args, **kwds
  197. ):
  198. return array_namespace(vector, widths, max_distances)
  199. def findfreqs_signature(num, den, N, kind='ba'):
  200. return array_namespace(num, den)
  201. def firls_signature(numtaps, bands, desired, *, weight=None, fs=None):
  202. return array_namespace(bands, desired, weight)
  203. def firwin_signature(numtaps, cutoff, *args, **kwds):
  204. if isinstance(cutoff, int | float):
  205. xp = np_compat
  206. else:
  207. xp = array_namespace(cutoff)
  208. return xp
  209. def firwin2_signature(numtaps, freq, gain, *args, **kwds):
  210. return array_namespace(freq, gain)
  211. def freqs_zpk_signature(z, p, k, worN=200, *args, **kwds):
  212. return array_namespace(z, p, worN)
  213. freqz_zpk_signature = freqs_zpk_signature
  214. def freqs_signature(b, a, worN=200, *args, **kwds):
  215. return array_namespace(b, a, worN)
  216. def freqz_signature(b, a=1, worN=512, *args, **kwds):
  217. # differs from freqs: `a` has a default value
  218. return array_namespace(b, a, worN)
  219. def freqz_sos_signature(sos, worN=512, *args, **kwds):
  220. return array_namespace(sos, worN)
  221. sosfreqz_signature = freqz_sos_signature
  222. def gausspulse_signature(t, *args, **kwds):
  223. arr_t = None if isinstance(t, str) else t
  224. return array_namespace(arr_t)
  225. def group_delay_signature(system, w=512, whole=False, fs=6.283185307179586):
  226. return array_namespace(*system, w)
  227. def hilbert_signature(x, *args, **kwds):
  228. return array_namespace(x)
  229. hilbert2_signature = hilbert_signature
  230. def iirdesign_signature(wp, ws, *args, **kwds):
  231. return array_namespace(wp, ws)
  232. def iirfilter_signature(N, Wn, *args, **kwds):
  233. return array_namespace(Wn)
  234. def invres_signature(r, p, k, tol=0.001, rtype='avg'):
  235. return array_namespace(r, p, k)
  236. invresz_signature = invres_signature
  237. ############################### XXX: excluded, blacklisted on CuPy (mismatched API)
  238. def lfilter_zi_signature(b, a):
  239. return array_namespace(b, a)
  240. def sosfilt_zi_signature(sos):
  241. return array_namespace(sos)
  242. # needs to be blacklisted on CuPy (is not implemented)
  243. def remez_signature(numtaps, bands, desired, *, weight=None, **kwds):
  244. return array_namespace(bands, desired, weight)
  245. #############################################
  246. def lfiltic_signature(b, a, y, x=None):
  247. return array_namespace(b, a, y, x)
  248. def lombscargle_signature(
  249. x, y, freqs, precenter=False, normalize=False, *,
  250. weights=None, floating_mean=False
  251. ):
  252. return array_namespace(x, y, freqs, weights)
  253. def lp2bp_signature(b, a, *args, **kwds):
  254. return array_namespace(b, a)
  255. lp2bs_signature = lp2bp_signature
  256. lp2hp_signature = lp2bp_signature
  257. lp2lp_signature = lp2bp_signature
  258. tf2zpk_signature = lp2bp_signature
  259. tf2sos_signature = lp2bp_signature
  260. normalize_signature = lp2bp_signature
  261. residue_signature = lp2bp_signature
  262. residuez_signature = residue_signature
  263. def lp2bp_zpk_signature(z, p, k, *args, **kwds):
  264. return array_namespace(z, p)
  265. lp2bs_zpk_signature = lp2bp_zpk_signature
  266. lp2hp_zpk_signature = lp2bs_zpk_signature
  267. lp2lp_zpk_signature = lp2bs_zpk_signature
  268. def zpk2sos_signature(z, p, k, *args, **kwds):
  269. return array_namespace(z, p)
  270. zpk2ss_signature = zpk2sos_signature
  271. zpk2tf_signature = zpk2sos_signature
  272. def max_len_seq_signature(nbits, state=None, length=None, taps=None):
  273. return array_namespace(state, taps)
  274. def medfilt_signature(volume, kernel_size=None):
  275. return array_namespace(volume)
  276. def medfilt2d_signature(input, kernel_size=3):
  277. return array_namespace(input)
  278. def minimum_phase_signature(h, *args, **kwds):
  279. return array_namespace(h)
  280. def order_filter_signature(a, domain, rank):
  281. return array_namespace(a, domain)
  282. def peak_prominences_signature(x, peaks, *args, **kwds):
  283. return array_namespace(x, peaks)
  284. peak_widths_signature = peak_prominences_signature
  285. def place_poles_signature(A, B, poles, method='YT', rtol=0.001, maxiter=30):
  286. return array_namespace(A, B, poles)
  287. def savgol_filter_signature(x, *args, **kwds):
  288. return array_namespace(x)
  289. def sawtooth_signature(t, width=1):
  290. return array_namespace(t)
  291. def sepfir2d_signature(input, hrow, hcol):
  292. return array_namespace(input, hrow, hcol)
  293. def sos2tf_signature(sos):
  294. return array_namespace(sos)
  295. sos2zpk_signature = sos2tf_signature
  296. def sosfilt_signature(sos, x, axis=-1, zi=None):
  297. return array_namespace(sos, x, zi)
  298. def sosfiltfilt_signature(sos, x, *args, **kwds):
  299. return array_namespace(sos, x)
  300. def spline_filter_signature(Iin, lmbda=5.0):
  301. return array_namespace(Iin)
  302. def square_signature(t, duty=0.5):
  303. return array_namespace(t)
  304. def ss2tf_signature(A, B, C, D, input=0):
  305. return array_namespace(A, B, C, D)
  306. ss2zpk_signature = ss2tf_signature
  def sweep_poly_signature(t, poly, phi=0):
      """Delegate on ``t`` and on ``poly`` unless it is a ``np.poly1d``."""
      coeffs = _skip_if_poly1d(poly)
      return array_namespace(t, coeffs)
  309. def symiirorder1_signature(signal, c0, z1, precision=-1.0):
  310. return array_namespace(signal)
  311. def symiirorder2_signature(input, r, omega, precision=-1.0):
  312. return array_namespace(input, r, omega)
  313. def cspline1d_signature(signal, *args, **kwds):
  314. return array_namespace(signal)
  315. qspline1d_signature = cspline1d_signature
  316. cspline2d_signature = cspline1d_signature
  317. qspline2d_signature = qspline1d_signature
  318. def cspline1d_eval_signature(cj, newx, *args, **kwds):
  319. return array_namespace(cj, newx)
  320. qspline1d_eval_signature = cspline1d_eval_signature
  321. def tf2ss_signature(num, den):
  322. return array_namespace(num, den)
  323. def unique_roots_signature(p, tol=0.001, rtype='min'):
  324. return array_namespace(p)
  325. def upfirdn_signature(h, x, up=1, down=1, axis=-1, mode='constant', cval=0):
  326. return array_namespace(h, x)
  327. def vectorstrength_signature(events, period):
  328. return array_namespace(events, period)
  329. def wiener_signature(im, mysize=None, noise=None):
  330. return array_namespace(im)
  331. def zoom_fft_signature(x, fn, m=None, *, fs=2, endpoint=False, axis=-1):
  332. return array_namespace(x, fn)