_realtransforms_backend.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. from scipy._lib._array_api import array_namespace
  2. import numpy as np
  3. from . import _pocketfft
  4. __all__ = ['dct', 'idct', 'dst', 'idst', 'dctn', 'idctn', 'dstn', 'idstn']
  5. def _execute(pocketfft_func, x, type, s, axes, norm,
  6. overwrite_x, workers, orthogonalize):
  7. xp = array_namespace(x)
  8. x = np.asarray(x)
  9. y = pocketfft_func(x, type, s, axes, norm,
  10. overwrite_x=overwrite_x, workers=workers,
  11. orthogonalize=orthogonalize)
  12. return xp.asarray(y)
  13. def dctn(x, type=2, s=None, axes=None, norm=None,
  14. overwrite_x=False, workers=None, *, orthogonalize=None):
  15. return _execute(_pocketfft.dctn, x, type, s, axes, norm,
  16. overwrite_x, workers, orthogonalize)
  17. def idctn(x, type=2, s=None, axes=None, norm=None,
  18. overwrite_x=False, workers=None, *, orthogonalize=None):
  19. return _execute(_pocketfft.idctn, x, type, s, axes, norm,
  20. overwrite_x, workers, orthogonalize)
  21. def dstn(x, type=2, s=None, axes=None, norm=None,
  22. overwrite_x=False, workers=None, orthogonalize=None):
  23. return _execute(_pocketfft.dstn, x, type, s, axes, norm,
  24. overwrite_x, workers, orthogonalize)
  25. def idstn(x, type=2, s=None, axes=None, norm=None,
  26. overwrite_x=False, workers=None, *, orthogonalize=None):
  27. return _execute(_pocketfft.idstn, x, type, s, axes, norm,
  28. overwrite_x, workers, orthogonalize)
  29. def dct(x, type=2, n=None, axis=-1, norm=None,
  30. overwrite_x=False, workers=None, orthogonalize=None):
  31. return _execute(_pocketfft.dct, x, type, n, axis, norm,
  32. overwrite_x, workers, orthogonalize)
  33. def idct(x, type=2, n=None, axis=-1, norm=None,
  34. overwrite_x=False, workers=None, orthogonalize=None):
  35. return _execute(_pocketfft.idct, x, type, n, axis, norm,
  36. overwrite_x, workers, orthogonalize)
  37. def dst(x, type=2, n=None, axis=-1, norm=None,
  38. overwrite_x=False, workers=None, orthogonalize=None):
  39. return _execute(_pocketfft.dst, x, type, n, axis, norm,
  40. overwrite_x, workers, orthogonalize)
  41. def idst(x, type=2, n=None, axis=-1, norm=None,
  42. overwrite_x=False, workers=None, orthogonalize=None):
  43. return _execute(_pocketfft.idst, x, type, n, axis, norm,
  44. overwrite_x, workers, orthogonalize)