_basic_backend.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. from scipy._lib._array_api import (
  2. array_namespace, is_numpy, xp_unsupported_param_msg, is_complex, xp_float_to_complex
  3. )
  4. from . import _pocketfft
  5. import numpy as np
  6. def _validate_fft_args(workers, plan, norm):
  7. if workers is not None:
  8. raise ValueError(xp_unsupported_param_msg("workers"))
  9. if plan is not None:
  10. raise ValueError(xp_unsupported_param_msg("plan"))
  11. if norm is None:
  12. norm = 'backward'
  13. return norm
  14. # these functions expect complex input in the fft standard extension
  15. complex_funcs = {'fft', 'ifft', 'fftn', 'ifftn', 'hfft', 'irfft', 'irfftn'}
  16. # pocketfft is used whenever SCIPY_ARRAY_API is not set,
  17. # or x is a NumPy array or array-like.
  18. # When SCIPY_ARRAY_API is set, we try to use xp.fft for CuPy arrays,
  19. # PyTorch arrays and other array API standard supporting objects.
  20. # If xp.fft does not exist, we attempt to convert to np and back to use pocketfft.
  21. def _execute_1D(func_str, pocketfft_func, x, n, axis, norm, overwrite_x, workers, plan):
  22. xp = array_namespace(x)
  23. if is_numpy(xp):
  24. x = np.asarray(x)
  25. return pocketfft_func(x, n=n, axis=axis, norm=norm,
  26. overwrite_x=overwrite_x, workers=workers, plan=plan)
  27. norm = _validate_fft_args(workers, plan, norm)
  28. if hasattr(xp, 'fft'):
  29. xp_func = getattr(xp.fft, func_str)
  30. if func_str in complex_funcs:
  31. try:
  32. res = xp_func(x, n=n, axis=axis, norm=norm)
  33. except: # backends may require complex input # noqa: E722
  34. x = xp_float_to_complex(x, xp)
  35. res = xp_func(x, n=n, axis=axis, norm=norm)
  36. return res
  37. return xp_func(x, n=n, axis=axis, norm=norm)
  38. x = np.asarray(x)
  39. y = pocketfft_func(x, n=n, axis=axis, norm=norm)
  40. return xp.asarray(y)
  41. def _execute_nD(func_str, pocketfft_func, x, s, axes, norm, overwrite_x, workers, plan):
  42. xp = array_namespace(x)
  43. if is_numpy(xp):
  44. x = np.asarray(x)
  45. return pocketfft_func(x, s=s, axes=axes, norm=norm,
  46. overwrite_x=overwrite_x, workers=workers, plan=plan)
  47. norm = _validate_fft_args(workers, plan, norm)
  48. if hasattr(xp, 'fft'):
  49. xp_func = getattr(xp.fft, func_str)
  50. if func_str in complex_funcs:
  51. try:
  52. res = xp_func(x, s=s, axes=axes, norm=norm)
  53. except: # backends may require complex input # noqa: E722
  54. x = xp_float_to_complex(x, xp)
  55. res = xp_func(x, s=s, axes=axes, norm=norm)
  56. return res
  57. return xp_func(x, s=s, axes=axes, norm=norm)
  58. x = np.asarray(x)
  59. y = pocketfft_func(x, s=s, axes=axes, norm=norm)
  60. return xp.asarray(y)
  61. def fft(x, n=None, axis=-1, norm=None,
  62. overwrite_x=False, workers=None, *, plan=None):
  63. return _execute_1D('fft', _pocketfft.fft, x, n=n, axis=axis, norm=norm,
  64. overwrite_x=overwrite_x, workers=workers, plan=plan)
  65. def ifft(x, n=None, axis=-1, norm=None, overwrite_x=False, workers=None, *,
  66. plan=None):
  67. return _execute_1D('ifft', _pocketfft.ifft, x, n=n, axis=axis, norm=norm,
  68. overwrite_x=overwrite_x, workers=workers, plan=plan)
  69. def rfft(x, n=None, axis=-1, norm=None,
  70. overwrite_x=False, workers=None, *, plan=None):
  71. return _execute_1D('rfft', _pocketfft.rfft, x, n=n, axis=axis, norm=norm,
  72. overwrite_x=overwrite_x, workers=workers, plan=plan)
  73. def irfft(x, n=None, axis=-1, norm=None,
  74. overwrite_x=False, workers=None, *, plan=None):
  75. return _execute_1D('irfft', _pocketfft.irfft, x, n=n, axis=axis, norm=norm,
  76. overwrite_x=overwrite_x, workers=workers, plan=plan)
  77. def hfft(x, n=None, axis=-1, norm=None,
  78. overwrite_x=False, workers=None, *, plan=None):
  79. return _execute_1D('hfft', _pocketfft.hfft, x, n=n, axis=axis, norm=norm,
  80. overwrite_x=overwrite_x, workers=workers, plan=plan)
  81. def ihfft(x, n=None, axis=-1, norm=None,
  82. overwrite_x=False, workers=None, *, plan=None):
  83. return _execute_1D('ihfft', _pocketfft.ihfft, x, n=n, axis=axis, norm=norm,
  84. overwrite_x=overwrite_x, workers=workers, plan=plan)
  85. def fftn(x, s=None, axes=None, norm=None,
  86. overwrite_x=False, workers=None, *, plan=None):
  87. return _execute_nD('fftn', _pocketfft.fftn, x, s=s, axes=axes, norm=norm,
  88. overwrite_x=overwrite_x, workers=workers, plan=plan)
  89. def ifftn(x, s=None, axes=None, norm=None,
  90. overwrite_x=False, workers=None, *, plan=None):
  91. return _execute_nD('ifftn', _pocketfft.ifftn, x, s=s, axes=axes, norm=norm,
  92. overwrite_x=overwrite_x, workers=workers, plan=plan)
  93. def fft2(x, s=None, axes=(-2, -1), norm=None,
  94. overwrite_x=False, workers=None, *, plan=None):
  95. return fftn(x, s, axes, norm, overwrite_x, workers, plan=plan)
  96. def ifft2(x, s=None, axes=(-2, -1), norm=None,
  97. overwrite_x=False, workers=None, *, plan=None):
  98. return ifftn(x, s, axes, norm, overwrite_x, workers, plan=plan)
  99. def rfftn(x, s=None, axes=None, norm=None,
  100. overwrite_x=False, workers=None, *, plan=None):
  101. return _execute_nD('rfftn', _pocketfft.rfftn, x, s=s, axes=axes, norm=norm,
  102. overwrite_x=overwrite_x, workers=workers, plan=plan)
  103. def rfft2(x, s=None, axes=(-2, -1), norm=None,
  104. overwrite_x=False, workers=None, *, plan=None):
  105. return rfftn(x, s, axes, norm, overwrite_x, workers, plan=plan)
  106. def irfftn(x, s=None, axes=None, norm=None,
  107. overwrite_x=False, workers=None, *, plan=None):
  108. return _execute_nD('irfftn', _pocketfft.irfftn, x, s=s, axes=axes, norm=norm,
  109. overwrite_x=overwrite_x, workers=workers, plan=plan)
  110. def irfft2(x, s=None, axes=(-2, -1), norm=None,
  111. overwrite_x=False, workers=None, *, plan=None):
  112. return irfftn(x, s, axes, norm, overwrite_x, workers, plan=plan)
  113. def _swap_direction(norm):
  114. if norm in (None, 'backward'):
  115. norm = 'forward'
  116. elif norm == 'forward':
  117. norm = 'backward'
  118. elif norm != 'ortho':
  119. raise ValueError(f'Invalid norm value {norm}; should be "backward", '
  120. '"ortho", or "forward".')
  121. return norm
  122. def hfftn(x, s=None, axes=None, norm=None,
  123. overwrite_x=False, workers=None, *, plan=None):
  124. xp = array_namespace(x)
  125. if is_numpy(xp):
  126. x = np.asarray(x)
  127. return _pocketfft.hfftn(x, s, axes, norm, overwrite_x, workers, plan=plan)
  128. if is_complex(x, xp):
  129. x = xp.conj(x)
  130. return irfftn(x, s, axes, _swap_direction(norm),
  131. overwrite_x, workers, plan=plan)
  132. def hfft2(x, s=None, axes=(-2, -1), norm=None,
  133. overwrite_x=False, workers=None, *, plan=None):
  134. return hfftn(x, s, axes, norm, overwrite_x, workers, plan=plan)
  135. def ihfftn(x, s=None, axes=None, norm=None,
  136. overwrite_x=False, workers=None, *, plan=None):
  137. xp = array_namespace(x)
  138. if is_numpy(xp):
  139. x = np.asarray(x)
  140. return _pocketfft.ihfftn(x, s, axes, norm, overwrite_x, workers, plan=plan)
  141. return xp.conj(rfftn(x, s, axes, _swap_direction(norm),
  142. overwrite_x, workers, plan=plan))
  143. def ihfft2(x, s=None, axes=(-2, -1), norm=None,
  144. overwrite_x=False, workers=None, *, plan=None):
  145. return ihfftn(x, s, axes, norm, overwrite_x, workers, plan=plan)