# realtransforms.py (3.3 KB)
  1. import numpy as np
  2. from . import pypocketfft as pfft
  3. from .helper import (_asfarray, _init_nd_shape_and_axes, _datacopied,
  4. _fix_shape, _fix_shape_1d, _normalization, _workers)
  5. import functools
  6. def _r2r(forward, transform, x, type=2, n=None, axis=-1, norm=None,
  7. overwrite_x=False, workers=None, orthogonalize=None):
  8. """Forward or backward 1-D DCT/DST
  9. Parameters
  10. ----------
  11. forward : bool
  12. Transform direction (determines type and normalisation)
  13. transform : {pypocketfft.dct, pypocketfft.dst}
  14. The transform to perform
  15. """
  16. tmp = _asfarray(x)
  17. overwrite_x = overwrite_x or _datacopied(tmp, x)
  18. norm = _normalization(norm, forward)
  19. workers = _workers(workers)
  20. if not forward:
  21. if type == 2:
  22. type = 3
  23. elif type == 3:
  24. type = 2
  25. if n is not None:
  26. tmp, copied = _fix_shape_1d(tmp, n, axis)
  27. overwrite_x = overwrite_x or copied
  28. elif tmp.shape[axis] < 1:
  29. raise ValueError(f"invalid number of data points ({tmp.shape[axis]}) specified")
  30. out = (tmp if overwrite_x else None)
  31. # For complex input, transform real and imaginary components separably
  32. if np.iscomplexobj(x):
  33. out = np.empty_like(tmp) if out is None else out
  34. transform(tmp.real, type, (axis,), norm, out.real, workers)
  35. transform(tmp.imag, type, (axis,), norm, out.imag, workers)
  36. return out
  37. return transform(tmp, type, (axis,), norm, out, workers, orthogonalize)
  38. dct = functools.partial(_r2r, True, pfft.dct)
  39. dct.__name__ = 'dct'
  40. idct = functools.partial(_r2r, False, pfft.dct)
  41. idct.__name__ = 'idct'
  42. dst = functools.partial(_r2r, True, pfft.dst)
  43. dst.__name__ = 'dst'
  44. idst = functools.partial(_r2r, False, pfft.dst)
  45. idst.__name__ = 'idst'
  46. def _r2rn(forward, transform, x, type=2, s=None, axes=None, norm=None,
  47. overwrite_x=False, workers=None, orthogonalize=None):
  48. """Forward or backward nd DCT/DST
  49. Parameters
  50. ----------
  51. forward : bool
  52. Transform direction (determines type and normalisation)
  53. transform : {pypocketfft.dct, pypocketfft.dst}
  54. The transform to perform
  55. """
  56. tmp = _asfarray(x)
  57. shape, axes = _init_nd_shape_and_axes(tmp, s, axes)
  58. overwrite_x = overwrite_x or _datacopied(tmp, x)
  59. if len(axes) == 0:
  60. return x
  61. tmp, copied = _fix_shape(tmp, shape, axes)
  62. overwrite_x = overwrite_x or copied
  63. if not forward:
  64. if type == 2:
  65. type = 3
  66. elif type == 3:
  67. type = 2
  68. norm = _normalization(norm, forward)
  69. workers = _workers(workers)
  70. out = (tmp if overwrite_x else None)
  71. # For complex input, transform real and imaginary components separably
  72. if np.iscomplexobj(x):
  73. out = np.empty_like(tmp) if out is None else out
  74. transform(tmp.real, type, axes, norm, out.real, workers)
  75. transform(tmp.imag, type, axes, norm, out.imag, workers)
  76. return out
  77. return transform(tmp, type, axes, norm, out, workers, orthogonalize)
  78. dctn = functools.partial(_r2rn, True, pfft.dct)
  79. dctn.__name__ = 'dctn'
  80. idctn = functools.partial(_r2rn, False, pfft.dct)
  81. idctn.__name__ = 'idctn'
  82. dstn = functools.partial(_r2rn, True, pfft.dst)
  83. dstn.__name__ = 'dstn'
  84. idstn = functools.partial(_r2rn, False, pfft.dst)
  85. idstn.__name__ = 'idstn'