basic.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. """
  2. Discrete Fourier Transforms - basic.py
  3. """
  4. import numpy as np
  5. import functools
  6. from . import pypocketfft as pfft
  7. from .helper import (_asfarray, _init_nd_shape_and_axes, _datacopied,
  8. _fix_shape, _fix_shape_1d, _normalization,
  9. _workers)
  10. def c2c(forward, x, n=None, axis=-1, norm=None, overwrite_x=False,
  11. workers=None, *, plan=None):
  12. """ Return discrete Fourier transform of real or complex sequence. """
  13. if plan is not None:
  14. raise NotImplementedError('Passing a precomputed plan is not yet '
  15. 'supported by scipy.fft functions')
  16. tmp = _asfarray(x)
  17. overwrite_x = overwrite_x or _datacopied(tmp, x)
  18. norm = _normalization(norm, forward)
  19. workers = _workers(workers)
  20. if n is not None:
  21. tmp, copied = _fix_shape_1d(tmp, n, axis)
  22. overwrite_x = overwrite_x or copied
  23. elif tmp.shape[axis] < 1:
  24. message = f"invalid number of data points ({tmp.shape[axis]}) specified"
  25. raise ValueError(message)
  26. out = (tmp if overwrite_x and tmp.dtype.kind == 'c' else None)
  27. return pfft.c2c(tmp, (axis,), forward, norm, out, workers)
  28. fft = functools.partial(c2c, True)
  29. fft.__name__ = 'fft'
  30. ifft = functools.partial(c2c, False)
  31. ifft.__name__ = 'ifft'
  32. def r2c(forward, x, n=None, axis=-1, norm=None, overwrite_x=False,
  33. workers=None, *, plan=None):
  34. """
  35. Discrete Fourier transform of a real sequence.
  36. """
  37. if plan is not None:
  38. raise NotImplementedError('Passing a precomputed plan is not yet '
  39. 'supported by scipy.fft functions')
  40. tmp = _asfarray(x)
  41. norm = _normalization(norm, forward)
  42. workers = _workers(workers)
  43. if not np.isrealobj(tmp):
  44. raise TypeError("x must be a real sequence")
  45. if n is not None:
  46. tmp, _ = _fix_shape_1d(tmp, n, axis)
  47. elif tmp.shape[axis] < 1:
  48. raise ValueError(f"invalid number of data points ({tmp.shape[axis]}) specified")
  49. # Note: overwrite_x is not utilised
  50. return pfft.r2c(tmp, (axis,), forward, norm, None, workers)
  51. rfft = functools.partial(r2c, True)
  52. rfft.__name__ = 'rfft'
  53. ihfft = functools.partial(r2c, False)
  54. ihfft.__name__ = 'ihfft'
  55. def c2r(forward, x, n=None, axis=-1, norm=None, overwrite_x=False,
  56. workers=None, *, plan=None):
  57. """
  58. Return inverse discrete Fourier transform of real sequence x.
  59. """
  60. if plan is not None:
  61. raise NotImplementedError('Passing a precomputed plan is not yet '
  62. 'supported by scipy.fft functions')
  63. tmp = _asfarray(x)
  64. norm = _normalization(norm, forward)
  65. workers = _workers(workers)
  66. # TODO: Optimize for hermitian and real?
  67. if np.isrealobj(tmp):
  68. tmp = tmp + 0.j
  69. # Last axis utilizes hermitian symmetry
  70. if n is None:
  71. n = (tmp.shape[axis] - 1) * 2
  72. if n < 1:
  73. raise ValueError(f"Invalid number of data points ({n}) specified")
  74. else:
  75. tmp, _ = _fix_shape_1d(tmp, (n//2) + 1, axis)
  76. # Note: overwrite_x is not utilized
  77. return pfft.c2r(tmp, (axis,), n, forward, norm, None, workers)
  78. hfft = functools.partial(c2r, True)
  79. hfft.__name__ = 'hfft'
  80. irfft = functools.partial(c2r, False)
  81. irfft.__name__ = 'irfft'
  82. def hfft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
  83. *, plan=None):
  84. """
  85. 2-D discrete Fourier transform of a Hermitian sequence
  86. """
  87. if plan is not None:
  88. raise NotImplementedError('Passing a precomputed plan is not yet '
  89. 'supported by scipy.fft functions')
  90. return hfftn(x, s, axes, norm, overwrite_x, workers)
  91. def ihfft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
  92. *, plan=None):
  93. """
  94. 2-D discrete inverse Fourier transform of a Hermitian sequence
  95. """
  96. if plan is not None:
  97. raise NotImplementedError('Passing a precomputed plan is not yet '
  98. 'supported by scipy.fft functions')
  99. return ihfftn(x, s, axes, norm, overwrite_x, workers)
  100. def c2cn(forward, x, s=None, axes=None, norm=None, overwrite_x=False,
  101. workers=None, *, plan=None):
  102. """
  103. Return multidimensional discrete Fourier transform.
  104. """
  105. if plan is not None:
  106. raise NotImplementedError('Passing a precomputed plan is not yet '
  107. 'supported by scipy.fft functions')
  108. tmp = _asfarray(x)
  109. shape, axes = _init_nd_shape_and_axes(tmp, s, axes)
  110. overwrite_x = overwrite_x or _datacopied(tmp, x)
  111. workers = _workers(workers)
  112. if len(axes) == 0:
  113. return x
  114. tmp, copied = _fix_shape(tmp, shape, axes)
  115. overwrite_x = overwrite_x or copied
  116. norm = _normalization(norm, forward)
  117. out = (tmp if overwrite_x and tmp.dtype.kind == 'c' else None)
  118. return pfft.c2c(tmp, axes, forward, norm, out, workers)
  119. fftn = functools.partial(c2cn, True)
  120. fftn.__name__ = 'fftn'
  121. ifftn = functools.partial(c2cn, False)
  122. ifftn.__name__ = 'ifftn'
  123. def r2cn(forward, x, s=None, axes=None, norm=None, overwrite_x=False,
  124. workers=None, *, plan=None):
  125. """Return multidimensional discrete Fourier transform of real input"""
  126. if plan is not None:
  127. raise NotImplementedError('Passing a precomputed plan is not yet '
  128. 'supported by scipy.fft functions')
  129. tmp = _asfarray(x)
  130. if not np.isrealobj(tmp):
  131. raise TypeError("x must be a real sequence")
  132. shape, axes = _init_nd_shape_and_axes(tmp, s, axes)
  133. tmp, _ = _fix_shape(tmp, shape, axes)
  134. norm = _normalization(norm, forward)
  135. workers = _workers(workers)
  136. if len(axes) == 0:
  137. raise ValueError("at least 1 axis must be transformed")
  138. # Note: overwrite_x is not utilized
  139. return pfft.r2c(tmp, axes, forward, norm, None, workers)
  140. rfftn = functools.partial(r2cn, True)
  141. rfftn.__name__ = 'rfftn'
  142. ihfftn = functools.partial(r2cn, False)
  143. ihfftn.__name__ = 'ihfftn'
  144. def c2rn(forward, x, s=None, axes=None, norm=None, overwrite_x=False,
  145. workers=None, *, plan=None):
  146. """Multidimensional inverse discrete fourier transform with real output"""
  147. if plan is not None:
  148. raise NotImplementedError('Passing a precomputed plan is not yet '
  149. 'supported by scipy.fft functions')
  150. tmp = _asfarray(x)
  151. # TODO: Optimize for hermitian and real?
  152. if np.isrealobj(tmp):
  153. tmp = tmp + 0.j
  154. noshape = s is None
  155. shape, axes = _init_nd_shape_and_axes(tmp, s, axes)
  156. if len(axes) == 0:
  157. raise ValueError("at least 1 axis must be transformed")
  158. shape = list(shape)
  159. if noshape:
  160. shape[-1] = (x.shape[axes[-1]] - 1) * 2
  161. norm = _normalization(norm, forward)
  162. workers = _workers(workers)
  163. # Last axis utilizes hermitian symmetry
  164. lastsize = shape[-1]
  165. shape[-1] = (shape[-1] // 2) + 1
  166. tmp, _ = tuple(_fix_shape(tmp, shape, axes))
  167. # Note: overwrite_x is not utilized
  168. return pfft.c2r(tmp, axes, lastsize, forward, norm, None, workers)
  169. hfftn = functools.partial(c2rn, True)
  170. hfftn.__name__ = 'hfftn'
  171. irfftn = functools.partial(c2rn, False)
  172. irfftn.__name__ = 'irfftn'
  173. def r2r_fftpack(forward, x, n=None, axis=-1, norm=None, overwrite_x=False):
  174. """FFT of a real sequence, returning fftpack half complex format"""
  175. tmp = _asfarray(x)
  176. overwrite_x = overwrite_x or _datacopied(tmp, x)
  177. norm = _normalization(norm, forward)
  178. workers = _workers(None)
  179. if tmp.dtype.kind == 'c':
  180. raise TypeError('x must be a real sequence')
  181. if n is not None:
  182. tmp, copied = _fix_shape_1d(tmp, n, axis)
  183. overwrite_x = overwrite_x or copied
  184. elif tmp.shape[axis] < 1:
  185. raise ValueError(f"invalid number of data points ({tmp.shape[axis]}) specified")
  186. out = (tmp if overwrite_x else None)
  187. return pfft.r2r_fftpack(tmp, (axis,), forward, forward, norm, out, workers)
  188. rfft_fftpack = functools.partial(r2r_fftpack, True)
  189. rfft_fftpack.__name__ = 'rfft_fftpack'
  190. irfft_fftpack = functools.partial(r2r_fftpack, False)
  191. irfft_fftpack.__name__ = 'irfft_fftpack'