test_basic.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877
  1. # Created by Pearu Peterson, September 2002
  2. from numpy.testing import (assert_, assert_equal, assert_array_almost_equal,
  3. assert_array_almost_equal_nulp, assert_array_less)
  4. import pytest
  5. from pytest import raises as assert_raises
  6. from scipy.fftpack import ifft, fft, fftn, ifftn, rfft, irfft, fft2
  7. from numpy import (arange, array, asarray, zeros, dot, exp, pi,
  8. swapaxes, double, cdouble)
  9. import numpy as np
  10. import numpy.fft
  11. from numpy.random import rand
  12. # "large" composite numbers supported by FFTPACK
  13. LARGE_COMPOSITE_SIZES = [
  14. 2**13,
  15. 2**5 * 3**5,
  16. 2**3 * 3**3 * 5**2,
  17. ]
  18. SMALL_COMPOSITE_SIZES = [
  19. 2,
  20. 2*3*5,
  21. 2*2*3*3,
  22. ]
  23. # prime
  24. LARGE_PRIME_SIZES = [
  25. 2011
  26. ]
  27. SMALL_PRIME_SIZES = [
  28. 29
  29. ]
  30. def _assert_close_in_norm(x, y, rtol, size, rdt):
  31. # helper function for testing
  32. err_msg = f"size: {size} rdt: {rdt}"
  33. assert_array_less(np.linalg.norm(x - y), rtol*np.linalg.norm(x), err_msg)
  34. def random(size):
  35. return rand(*size)
  36. def direct_dft(x):
  37. x = asarray(x)
  38. n = len(x)
  39. y = zeros(n, dtype=cdouble)
  40. w = -arange(n)*(2j*pi/n)
  41. for i in range(n):
  42. y[i] = dot(exp(i*w), x)
  43. return y
  44. def direct_idft(x):
  45. x = asarray(x)
  46. n = len(x)
  47. y = zeros(n, dtype=cdouble)
  48. w = arange(n)*(2j*pi/n)
  49. for i in range(n):
  50. y[i] = dot(exp(i*w), x)/n
  51. return y
  52. def direct_dftn(x):
  53. x = asarray(x)
  54. for axis in range(len(x.shape)):
  55. x = fft(x, axis=axis)
  56. return x
  57. def direct_idftn(x):
  58. x = asarray(x)
  59. for axis in range(len(x.shape)):
  60. x = ifft(x, axis=axis)
  61. return x
  62. def direct_rdft(x):
  63. x = asarray(x)
  64. n = len(x)
  65. w = -arange(n)*(2j*pi/n)
  66. r = zeros(n, dtype=double)
  67. for i in range(n//2+1):
  68. y = dot(exp(i*w), x)
  69. if i:
  70. r[2*i-1] = y.real
  71. if 2*i < n:
  72. r[2*i] = y.imag
  73. else:
  74. r[0] = y.real
  75. return r
  76. def direct_irdft(x):
  77. x = asarray(x)
  78. n = len(x)
  79. x1 = zeros(n, dtype=cdouble)
  80. for i in range(n//2+1):
  81. if i:
  82. if 2*i < n:
  83. x1[i] = x[2*i-1] + 1j*x[2*i]
  84. x1[n-i] = x[2*i-1] - 1j*x[2*i]
  85. else:
  86. x1[i] = x[2*i-1]
  87. else:
  88. x1[0] = x[0]
  89. return direct_idft(x1).real
  90. class _TestFFTBase:
  91. def setup_method(self):
  92. self.cdt = None
  93. self.rdt = None
  94. np.random.seed(1234)
  95. def test_definition(self):
  96. x = np.array([1,2,3,4+1j,1,2,3,4+2j], dtype=self.cdt)
  97. y = fft(x)
  98. assert_equal(y.dtype, self.cdt)
  99. y1 = direct_dft(x)
  100. assert_array_almost_equal(y,y1)
  101. x = np.array([1,2,3,4+0j,5], dtype=self.cdt)
  102. assert_array_almost_equal(fft(x),direct_dft(x))
  103. def test_n_argument_real(self):
  104. x1 = np.array([1,2,3,4], dtype=self.rdt)
  105. x2 = np.array([1,2,3,4], dtype=self.rdt)
  106. y = fft([x1,x2],n=4)
  107. assert_equal(y.dtype, self.cdt)
  108. assert_equal(y.shape,(2,4))
  109. assert_array_almost_equal(y[0],direct_dft(x1))
  110. assert_array_almost_equal(y[1],direct_dft(x2))
  111. def _test_n_argument_complex(self):
  112. x1 = np.array([1,2,3,4+1j], dtype=self.cdt)
  113. x2 = np.array([1,2,3,4+1j], dtype=self.cdt)
  114. y = fft([x1,x2],n=4)
  115. assert_equal(y.dtype, self.cdt)
  116. assert_equal(y.shape,(2,4))
  117. assert_array_almost_equal(y[0],direct_dft(x1))
  118. assert_array_almost_equal(y[1],direct_dft(x2))
  119. def test_invalid_sizes(self):
  120. assert_raises(ValueError, fft, [])
  121. assert_raises(ValueError, fft, [[1,1],[2,2]], -5)
  122. class TestDoubleFFT(_TestFFTBase):
  123. def setup_method(self):
  124. self.cdt = np.complex128
  125. self.rdt = np.float64
  126. class TestSingleFFT(_TestFFTBase):
  127. def setup_method(self):
  128. self.cdt = np.complex64
  129. self.rdt = np.float32
  130. reason = ("single-precision FFT implementation is partially disabled, "
  131. "until accuracy issues with large prime powers are resolved")
  132. @pytest.mark.xfail(run=False, reason=reason)
  133. def test_notice(self):
  134. pass
  135. class TestFloat16FFT:
  136. def test_1_argument_real(self):
  137. x1 = np.array([1, 2, 3, 4], dtype=np.float16)
  138. y = fft(x1, n=4)
  139. assert_equal(y.dtype, np.complex64)
  140. assert_equal(y.shape, (4, ))
  141. assert_array_almost_equal(y, direct_dft(x1.astype(np.float32)))
  142. def test_n_argument_real(self):
  143. x1 = np.array([1, 2, 3, 4], dtype=np.float16)
  144. x2 = np.array([1, 2, 3, 4], dtype=np.float16)
  145. y = fft([x1, x2], n=4)
  146. assert_equal(y.dtype, np.complex64)
  147. assert_equal(y.shape, (2, 4))
  148. assert_array_almost_equal(y[0], direct_dft(x1.astype(np.float32)))
  149. assert_array_almost_equal(y[1], direct_dft(x2.astype(np.float32)))
  150. class _TestIFFTBase:
  151. def setup_method(self):
  152. np.random.seed(1234)
  153. def test_definition(self):
  154. x = np.array([1,2,3,4+1j,1,2,3,4+2j], self.cdt)
  155. y = ifft(x)
  156. y1 = direct_idft(x)
  157. assert_equal(y.dtype, self.cdt)
  158. assert_array_almost_equal(y,y1)
  159. x = np.array([1,2,3,4+0j,5], self.cdt)
  160. assert_array_almost_equal(ifft(x),direct_idft(x))
  161. def test_definition_real(self):
  162. x = np.array([1,2,3,4,1,2,3,4], self.rdt)
  163. y = ifft(x)
  164. assert_equal(y.dtype, self.cdt)
  165. y1 = direct_idft(x)
  166. assert_array_almost_equal(y,y1)
  167. x = np.array([1,2,3,4,5], dtype=self.rdt)
  168. assert_equal(y.dtype, self.cdt)
  169. assert_array_almost_equal(ifft(x),direct_idft(x))
  170. def test_random_complex(self):
  171. for size in [1,51,111,100,200,64,128,256,1024]:
  172. x = random([size]).astype(self.cdt)
  173. x = random([size]).astype(self.cdt) + 1j*x
  174. y1 = ifft(fft(x))
  175. y2 = fft(ifft(x))
  176. assert_equal(y1.dtype, self.cdt)
  177. assert_equal(y2.dtype, self.cdt)
  178. assert_array_almost_equal(y1, x)
  179. assert_array_almost_equal(y2, x)
  180. def test_random_real(self):
  181. for size in [1,51,111,100,200,64,128,256,1024]:
  182. x = random([size]).astype(self.rdt)
  183. y1 = ifft(fft(x))
  184. y2 = fft(ifft(x))
  185. assert_equal(y1.dtype, self.cdt)
  186. assert_equal(y2.dtype, self.cdt)
  187. assert_array_almost_equal(y1, x)
  188. assert_array_almost_equal(y2, x)
  189. def test_size_accuracy(self):
  190. # Sanity check for the accuracy for prime and non-prime sized inputs
  191. if self.rdt == np.float32:
  192. rtol = 1e-5
  193. elif self.rdt == np.float64:
  194. rtol = 1e-10
  195. for size in LARGE_COMPOSITE_SIZES + LARGE_PRIME_SIZES:
  196. np.random.seed(1234)
  197. x = np.random.rand(size).astype(self.rdt)
  198. y = ifft(fft(x))
  199. _assert_close_in_norm(x, y, rtol, size, self.rdt)
  200. y = fft(ifft(x))
  201. _assert_close_in_norm(x, y, rtol, size, self.rdt)
  202. x = (x + 1j*np.random.rand(size)).astype(self.cdt)
  203. y = ifft(fft(x))
  204. _assert_close_in_norm(x, y, rtol, size, self.rdt)
  205. y = fft(ifft(x))
  206. _assert_close_in_norm(x, y, rtol, size, self.rdt)
  207. def test_invalid_sizes(self):
  208. assert_raises(ValueError, ifft, [])
  209. assert_raises(ValueError, ifft, [[1,1],[2,2]], -5)
  210. class TestDoubleIFFT(_TestIFFTBase):
  211. def setup_method(self):
  212. self.cdt = np.complex128
  213. self.rdt = np.float64
  214. class TestSingleIFFT(_TestIFFTBase):
  215. def setup_method(self):
  216. self.cdt = np.complex64
  217. self.rdt = np.float32
  218. class _TestRFFTBase:
  219. def setup_method(self):
  220. np.random.seed(1234)
  221. def test_definition(self):
  222. for t in [[1, 2, 3, 4, 1, 2, 3, 4], [1, 2, 3, 4, 1, 2, 3, 4, 5]]:
  223. x = np.array(t, dtype=self.rdt)
  224. y = rfft(x)
  225. y1 = direct_rdft(x)
  226. assert_array_almost_equal(y,y1)
  227. assert_equal(y.dtype, self.rdt)
  228. def test_invalid_sizes(self):
  229. assert_raises(ValueError, rfft, [])
  230. assert_raises(ValueError, rfft, [[1,1],[2,2]], -5)
  231. # See gh-5790
  232. class MockSeries:
  233. def __init__(self, data):
  234. self.data = np.asarray(data)
  235. def __getattr__(self, item):
  236. try:
  237. return getattr(self.data, item)
  238. except AttributeError as e:
  239. raise AttributeError("'MockSeries' object "
  240. f"has no attribute '{item}'") from e
  241. def test_non_ndarray_with_dtype(self):
  242. x = np.array([1., 2., 3., 4., 5.])
  243. xs = _TestRFFTBase.MockSeries(x)
  244. expected = [1, 2, 3, 4, 5]
  245. rfft(xs)
  246. # Data should not have been overwritten
  247. assert_equal(x, expected)
  248. assert_equal(xs.data, expected)
  249. def test_complex_input(self):
  250. assert_raises(TypeError, rfft, np.arange(4, dtype=np.complex64))
  251. class TestRFFTDouble(_TestRFFTBase):
  252. def setup_method(self):
  253. self.cdt = np.complex128
  254. self.rdt = np.float64
  255. class TestRFFTSingle(_TestRFFTBase):
  256. def setup_method(self):
  257. self.cdt = np.complex64
  258. self.rdt = np.float32
  259. class _TestIRFFTBase:
  260. def setup_method(self):
  261. np.random.seed(1234)
  262. def test_definition(self):
  263. x1 = [1,2,3,4,1,2,3,4]
  264. x1_1 = [1,2+3j,4+1j,2+3j,4,2-3j,4-1j,2-3j]
  265. x2 = [1,2,3,4,1,2,3,4,5]
  266. x2_1 = [1,2+3j,4+1j,2+3j,4+5j,4-5j,2-3j,4-1j,2-3j]
  267. def _test(x, xr):
  268. y = irfft(np.array(x, dtype=self.rdt))
  269. y1 = direct_irdft(x)
  270. assert_equal(y.dtype, self.rdt)
  271. assert_array_almost_equal(y,y1, decimal=self.ndec)
  272. assert_array_almost_equal(y,ifft(xr), decimal=self.ndec)
  273. _test(x1, x1_1)
  274. _test(x2, x2_1)
  275. def test_random_real(self):
  276. for size in [1,51,111,100,200,64,128,256,1024]:
  277. x = random([size]).astype(self.rdt)
  278. y1 = irfft(rfft(x))
  279. y2 = rfft(irfft(x))
  280. assert_equal(y1.dtype, self.rdt)
  281. assert_equal(y2.dtype, self.rdt)
  282. assert_array_almost_equal(y1, x, decimal=self.ndec, err_msg=f"size={size}")
  283. assert_array_almost_equal(y2, x, decimal=self.ndec, err_msg=f"size={size}")
  284. def test_size_accuracy(self):
  285. # Sanity check for the accuracy for prime and non-prime sized inputs
  286. if self.rdt == np.float32:
  287. rtol = 1e-5
  288. elif self.rdt == np.float64:
  289. rtol = 1e-10
  290. for size in LARGE_COMPOSITE_SIZES + LARGE_PRIME_SIZES:
  291. np.random.seed(1234)
  292. x = np.random.rand(size).astype(self.rdt)
  293. y = irfft(rfft(x))
  294. _assert_close_in_norm(x, y, rtol, size, self.rdt)
  295. y = rfft(irfft(x))
  296. _assert_close_in_norm(x, y, rtol, size, self.rdt)
  297. def test_invalid_sizes(self):
  298. assert_raises(ValueError, irfft, [])
  299. assert_raises(ValueError, irfft, [[1,1],[2,2]], -5)
  300. def test_complex_input(self):
  301. assert_raises(TypeError, irfft, np.arange(4, dtype=np.complex64))
# self.ndec is bogus; we should have an assert_array_approx_equal for the
# number of significant digits
  304. class TestIRFFTDouble(_TestIRFFTBase):
  305. def setup_method(self):
  306. self.cdt = np.complex128
  307. self.rdt = np.float64
  308. self.ndec = 14
  309. class TestIRFFTSingle(_TestIRFFTBase):
  310. def setup_method(self):
  311. self.cdt = np.complex64
  312. self.rdt = np.float32
  313. self.ndec = 5
  314. class Testfft2:
  315. def setup_method(self):
  316. np.random.seed(1234)
  317. def test_regression_244(self):
  318. """FFT returns wrong result with axes parameter."""
  319. # fftn (and hence fft2) used to break when both axes and shape were
  320. # used
  321. x = numpy.ones((4, 4, 2))
  322. y = fft2(x, shape=(8, 8), axes=(-3, -2))
  323. y_r = numpy.fft.fftn(x, s=(8, 8), axes=(-3, -2))
  324. assert_array_almost_equal(y, y_r)
  325. def test_invalid_sizes(self):
  326. assert_raises(ValueError, fft2, [[]])
  327. assert_raises(ValueError, fft2, [[1, 1], [2, 2]], (4, -3))
  328. class TestFftnSingle:
  329. def setup_method(self):
  330. np.random.seed(1234)
  331. def test_definition(self):
  332. x = [[1, 2, 3],
  333. [4, 5, 6],
  334. [7, 8, 9]]
  335. y = fftn(np.array(x, np.float32))
  336. assert_(y.dtype == np.complex64,
  337. msg="double precision output with single precision")
  338. y_r = np.array(fftn(x), np.complex64)
  339. assert_array_almost_equal_nulp(y, y_r)
  340. @pytest.mark.parametrize('size', SMALL_COMPOSITE_SIZES + SMALL_PRIME_SIZES)
  341. def test_size_accuracy_small(self, size):
  342. rng = np.random.default_rng(1234)
  343. x = rng.random((size, size)) + 1j*rng.random((size, size))
  344. y1 = fftn(x.real.astype(np.float32))
  345. y2 = fftn(x.real.astype(np.float64)).astype(np.complex64)
  346. assert_equal(y1.dtype, np.complex64)
  347. assert_array_almost_equal_nulp(y1, y2, 2000)
  348. @pytest.mark.parametrize('size', LARGE_COMPOSITE_SIZES + LARGE_PRIME_SIZES)
  349. def test_size_accuracy_large(self, size):
  350. rand = np.random.default_rng(1234)
  351. x = rand.random((size, 3)) + 1j*rand.random((size, 3))
  352. y1 = fftn(x.real.astype(np.float32))
  353. y2 = fftn(x.real.astype(np.float64)).astype(np.complex64)
  354. assert_equal(y1.dtype, np.complex64)
  355. assert_array_almost_equal_nulp(y1, y2, 2000)
  356. def test_definition_float16(self):
  357. x = [[1, 2, 3],
  358. [4, 5, 6],
  359. [7, 8, 9]]
  360. y = fftn(np.array(x, np.float16))
  361. assert_equal(y.dtype, np.complex64)
  362. y_r = np.array(fftn(x), np.complex64)
  363. assert_array_almost_equal_nulp(y, y_r)
  364. @pytest.mark.parametrize('size', SMALL_COMPOSITE_SIZES + SMALL_PRIME_SIZES)
  365. def test_float16_input_small(self, size):
  366. rng = np.random.default_rng(1234)
  367. x = rng.random((size, size)) + 1j * rng.random((size, size))
  368. y1 = fftn(x.real.astype(np.float16))
  369. y2 = fftn(x.real.astype(np.float64)).astype(np.complex64)
  370. assert_equal(y1.dtype, np.complex64)
  371. assert_array_almost_equal_nulp(y1, y2, 5e5)
  372. @pytest.mark.parametrize('size', LARGE_COMPOSITE_SIZES + LARGE_PRIME_SIZES)
  373. def test_float16_input_large(self, size):
  374. rng = np.random.default_rng(1234)
  375. x = rng.random((size, 3)) + 1j*rng.random((size, 3))
  376. y1 = fftn(x.real.astype(np.float16))
  377. y2 = fftn(x.real.astype(np.float64)).astype(np.complex64)
  378. assert_equal(y1.dtype, np.complex64)
  379. assert_array_almost_equal_nulp(y1, y2, 2e6)
  380. class TestFftn:
  381. def setup_method(self):
  382. np.random.seed(1234)
  383. def test_definition(self):
  384. x = [[1, 2, 3],
  385. [4, 5, 6],
  386. [7, 8, 9]]
  387. y = fftn(x)
  388. assert_array_almost_equal(y, direct_dftn(x))
  389. x = random((20, 26))
  390. assert_array_almost_equal(fftn(x), direct_dftn(x))
  391. x = random((5, 4, 3, 20))
  392. assert_array_almost_equal(fftn(x), direct_dftn(x))
  393. def test_axes_argument(self):
  394. # plane == ji_plane, x== kji_space
  395. plane1 = [[1, 2, 3],
  396. [4, 5, 6],
  397. [7, 8, 9]]
  398. plane2 = [[10, 11, 12],
  399. [13, 14, 15],
  400. [16, 17, 18]]
  401. plane3 = [[19, 20, 21],
  402. [22, 23, 24],
  403. [25, 26, 27]]
  404. ki_plane1 = [[1, 2, 3],
  405. [10, 11, 12],
  406. [19, 20, 21]]
  407. ki_plane2 = [[4, 5, 6],
  408. [13, 14, 15],
  409. [22, 23, 24]]
  410. ki_plane3 = [[7, 8, 9],
  411. [16, 17, 18],
  412. [25, 26, 27]]
  413. jk_plane1 = [[1, 10, 19],
  414. [4, 13, 22],
  415. [7, 16, 25]]
  416. jk_plane2 = [[2, 11, 20],
  417. [5, 14, 23],
  418. [8, 17, 26]]
  419. jk_plane3 = [[3, 12, 21],
  420. [6, 15, 24],
  421. [9, 18, 27]]
  422. kj_plane1 = [[1, 4, 7],
  423. [10, 13, 16], [19, 22, 25]]
  424. kj_plane2 = [[2, 5, 8],
  425. [11, 14, 17], [20, 23, 26]]
  426. kj_plane3 = [[3, 6, 9],
  427. [12, 15, 18], [21, 24, 27]]
  428. ij_plane1 = [[1, 4, 7],
  429. [2, 5, 8],
  430. [3, 6, 9]]
  431. ij_plane2 = [[10, 13, 16],
  432. [11, 14, 17],
  433. [12, 15, 18]]
  434. ij_plane3 = [[19, 22, 25],
  435. [20, 23, 26],
  436. [21, 24, 27]]
  437. ik_plane1 = [[1, 10, 19],
  438. [2, 11, 20],
  439. [3, 12, 21]]
  440. ik_plane2 = [[4, 13, 22],
  441. [5, 14, 23],
  442. [6, 15, 24]]
  443. ik_plane3 = [[7, 16, 25],
  444. [8, 17, 26],
  445. [9, 18, 27]]
  446. ijk_space = [jk_plane1, jk_plane2, jk_plane3]
  447. ikj_space = [kj_plane1, kj_plane2, kj_plane3]
  448. jik_space = [ik_plane1, ik_plane2, ik_plane3]
  449. jki_space = [ki_plane1, ki_plane2, ki_plane3]
  450. kij_space = [ij_plane1, ij_plane2, ij_plane3]
  451. x = array([plane1, plane2, plane3])
  452. assert_array_almost_equal(fftn(x),
  453. fftn(x, axes=(-3, -2, -1))) # kji_space
  454. assert_array_almost_equal(fftn(x), fftn(x, axes=(0, 1, 2)))
  455. assert_array_almost_equal(fftn(x, axes=(0, 2)), fftn(x, axes=(0, -1)))
  456. y = fftn(x, axes=(2, 1, 0)) # ijk_space
  457. assert_array_almost_equal(swapaxes(y, -1, -3), fftn(ijk_space))
  458. y = fftn(x, axes=(2, 0, 1)) # ikj_space
  459. assert_array_almost_equal(swapaxes(swapaxes(y, -1, -3), -1, -2),
  460. fftn(ikj_space))
  461. y = fftn(x, axes=(1, 2, 0)) # jik_space
  462. assert_array_almost_equal(swapaxes(swapaxes(y, -1, -3), -3, -2),
  463. fftn(jik_space))
  464. y = fftn(x, axes=(1, 0, 2)) # jki_space
  465. assert_array_almost_equal(swapaxes(y, -2, -3), fftn(jki_space))
  466. y = fftn(x, axes=(0, 2, 1)) # kij_space
  467. assert_array_almost_equal(swapaxes(y, -2, -1), fftn(kij_space))
  468. y = fftn(x, axes=(-2, -1)) # ji_plane
  469. assert_array_almost_equal(fftn(plane1), y[0])
  470. assert_array_almost_equal(fftn(plane2), y[1])
  471. assert_array_almost_equal(fftn(plane3), y[2])
  472. y = fftn(x, axes=(1, 2)) # ji_plane
  473. assert_array_almost_equal(fftn(plane1), y[0])
  474. assert_array_almost_equal(fftn(plane2), y[1])
  475. assert_array_almost_equal(fftn(plane3), y[2])
  476. y = fftn(x, axes=(-3, -2)) # kj_plane
  477. assert_array_almost_equal(fftn(x[:, :, 0]), y[:, :, 0])
  478. assert_array_almost_equal(fftn(x[:, :, 1]), y[:, :, 1])
  479. assert_array_almost_equal(fftn(x[:, :, 2]), y[:, :, 2])
  480. y = fftn(x, axes=(-3, -1)) # ki_plane
  481. assert_array_almost_equal(fftn(x[:, 0, :]), y[:, 0, :])
  482. assert_array_almost_equal(fftn(x[:, 1, :]), y[:, 1, :])
  483. assert_array_almost_equal(fftn(x[:, 2, :]), y[:, 2, :])
  484. y = fftn(x, axes=(-1, -2)) # ij_plane
  485. assert_array_almost_equal(fftn(ij_plane1), swapaxes(y[0], -2, -1))
  486. assert_array_almost_equal(fftn(ij_plane2), swapaxes(y[1], -2, -1))
  487. assert_array_almost_equal(fftn(ij_plane3), swapaxes(y[2], -2, -1))
  488. y = fftn(x, axes=(-1, -3)) # ik_plane
  489. assert_array_almost_equal(fftn(ik_plane1),
  490. swapaxes(y[:, 0, :], -1, -2))
  491. assert_array_almost_equal(fftn(ik_plane2),
  492. swapaxes(y[:, 1, :], -1, -2))
  493. assert_array_almost_equal(fftn(ik_plane3),
  494. swapaxes(y[:, 2, :], -1, -2))
  495. y = fftn(x, axes=(-2, -3)) # jk_plane
  496. assert_array_almost_equal(fftn(jk_plane1),
  497. swapaxes(y[:, :, 0], -1, -2))
  498. assert_array_almost_equal(fftn(jk_plane2),
  499. swapaxes(y[:, :, 1], -1, -2))
  500. assert_array_almost_equal(fftn(jk_plane3),
  501. swapaxes(y[:, :, 2], -1, -2))
  502. y = fftn(x, axes=(-1,)) # i_line
  503. for i in range(3):
  504. for j in range(3):
  505. assert_array_almost_equal(fft(x[i, j, :]), y[i, j, :])
  506. y = fftn(x, axes=(-2,)) # j_line
  507. for i in range(3):
  508. for j in range(3):
  509. assert_array_almost_equal(fft(x[i, :, j]), y[i, :, j])
  510. y = fftn(x, axes=(0,)) # k_line
  511. for i in range(3):
  512. for j in range(3):
  513. assert_array_almost_equal(fft(x[:, i, j]), y[:, i, j])
  514. y = fftn(x, axes=()) # point
  515. assert_array_almost_equal(y, x)
  516. def test_shape_argument(self):
  517. small_x = [[1, 2, 3],
  518. [4, 5, 6]]
  519. large_x1 = [[1, 2, 3, 0],
  520. [4, 5, 6, 0],
  521. [0, 0, 0, 0],
  522. [0, 0, 0, 0]]
  523. y = fftn(small_x, shape=(4, 4))
  524. assert_array_almost_equal(y, fftn(large_x1))
  525. y = fftn(small_x, shape=(3, 4))
  526. assert_array_almost_equal(y, fftn(large_x1[:-1]))
  527. def test_shape_axes_argument(self):
  528. small_x = [[1, 2, 3],
  529. [4, 5, 6],
  530. [7, 8, 9]]
  531. large_x1 = array([[1, 2, 3, 0],
  532. [4, 5, 6, 0],
  533. [7, 8, 9, 0],
  534. [0, 0, 0, 0]])
  535. y = fftn(small_x, shape=(4, 4), axes=(-2, -1))
  536. assert_array_almost_equal(y, fftn(large_x1))
  537. y = fftn(small_x, shape=(4, 4), axes=(-1, -2))
  538. assert_array_almost_equal(y, swapaxes(
  539. fftn(swapaxes(large_x1, -1, -2)), -1, -2))
  540. def test_shape_axes_argument2(self):
  541. # Change shape of the last axis
  542. x = numpy.random.random((10, 5, 3, 7))
  543. y = fftn(x, axes=(-1,), shape=(8,))
  544. assert_array_almost_equal(y, fft(x, axis=-1, n=8))
  545. # Change shape of an arbitrary axis which is not the last one
  546. x = numpy.random.random((10, 5, 3, 7))
  547. y = fftn(x, axes=(-2,), shape=(8,))
  548. assert_array_almost_equal(y, fft(x, axis=-2, n=8))
  549. # Change shape of axes: cf #244, where shape and axes were mixed up
  550. x = numpy.random.random((4, 4, 2))
  551. y = fftn(x, axes=(-3, -2), shape=(8, 8))
  552. assert_array_almost_equal(y,
  553. numpy.fft.fftn(x, axes=(-3, -2), s=(8, 8)))
  554. def test_shape_argument_more(self):
  555. x = zeros((4, 4, 2))
  556. with assert_raises(ValueError,
  557. match="when given, axes and shape arguments"
  558. " have to be of the same length"):
  559. fftn(x, shape=(8, 8, 2, 1))
  560. def test_invalid_sizes(self):
  561. with assert_raises(ValueError,
  562. match="invalid number of data points"
  563. r" \(\[1, 0\]\) specified"):
  564. fftn([[]])
  565. with assert_raises(ValueError,
  566. match="invalid number of data points"
  567. r" \(\[4, -3\]\) specified"):
  568. fftn([[1, 1], [2, 2]], (4, -3))
  569. class TestIfftn:
  570. dtype = None
  571. cdtype = None
  572. def setup_method(self):
  573. np.random.seed(1234)
  574. @pytest.mark.parametrize('dtype,cdtype,maxnlp',
  575. [(np.float64, np.complex128, 2000),
  576. (np.float32, np.complex64, 3500)])
  577. def test_definition(self, dtype, cdtype, maxnlp):
  578. rng = np.random.default_rng(1234)
  579. x = np.array([[1, 2, 3],
  580. [4, 5, 6],
  581. [7, 8, 9]], dtype=dtype)
  582. y = ifftn(x)
  583. assert_equal(y.dtype, cdtype)
  584. assert_array_almost_equal_nulp(y, direct_idftn(x), maxnlp)
  585. x = rng.random((20, 26))
  586. assert_array_almost_equal_nulp(ifftn(x), direct_idftn(x), maxnlp)
  587. x = rng.random((5, 4, 3, 20))
  588. assert_array_almost_equal_nulp(ifftn(x), direct_idftn(x), maxnlp)
  589. @pytest.mark.parametrize('maxnlp', [2000, 3500])
  590. @pytest.mark.parametrize('size', [1, 2, 51, 32, 64, 92])
  591. def test_random_complex(self, maxnlp, size):
  592. rng = np.random.default_rng(1234)
  593. x = rng.random([size, size]) + 1j * rng.random([size, size])
  594. assert_array_almost_equal_nulp(ifftn(fftn(x)), x, maxnlp)
  595. assert_array_almost_equal_nulp(fftn(ifftn(x)), x, maxnlp)
  596. def test_invalid_sizes(self):
  597. with assert_raises(ValueError,
  598. match="invalid number of data points"
  599. r" \(\[1, 0\]\) specified"):
  600. ifftn([[]])
  601. with assert_raises(ValueError,
  602. match="invalid number of data points"
  603. r" \(\[4, -3\]\) specified"):
  604. ifftn([[1, 1], [2, 2]], (4, -3))
  605. class FakeArray:
  606. def __init__(self, data):
  607. self._data = data
  608. self.__array_interface__ = data.__array_interface__
  609. class FakeArray2:
  610. def __init__(self, data):
  611. self._data = data
  612. def __array__(self, dtype=None, copy=None):
  613. return self._data
  614. class TestOverwrite:
  615. """Check input overwrite behavior of the FFT functions."""
  616. real_dtypes = (np.float32, np.float64)
  617. dtypes = real_dtypes + (np.complex64, np.complex128)
  618. fftsizes = [8, 16, 32]
  619. def _check(self, x, routine, fftsize, axis, overwrite_x):
  620. x2 = x.copy()
  621. for fake in [lambda x: x, FakeArray, FakeArray2]:
  622. routine(fake(x2), fftsize, axis, overwrite_x=overwrite_x)
  623. sig = (f"{routine.__name__}({x.dtype}{x.shape!r}, {fftsize!r}, "
  624. f"axis={axis!r}, overwrite_x={overwrite_x!r})")
  625. if not overwrite_x:
  626. assert_equal(x2, x, err_msg=f"spurious overwrite in {sig}")
  627. def _check_1d(self, routine, dtype, shape, axis, overwritable_dtypes,
  628. fftsize, overwrite_x):
  629. np.random.seed(1234)
  630. if np.issubdtype(dtype, np.complexfloating):
  631. data = np.random.randn(*shape) + 1j*np.random.randn(*shape)
  632. else:
  633. data = np.random.randn(*shape)
  634. data = data.astype(dtype)
  635. self._check(data, routine, fftsize, axis,
  636. overwrite_x=overwrite_x)
  637. @pytest.mark.parametrize('dtype', dtypes)
  638. @pytest.mark.parametrize('fftsize', fftsizes)
  639. @pytest.mark.parametrize('overwrite_x', [True, False])
  640. @pytest.mark.parametrize('shape,axes', [((16,), -1),
  641. ((16, 2), 0),
  642. ((2, 16), 1)])
  643. def test_fft_ifft(self, dtype, fftsize, overwrite_x, shape, axes):
  644. overwritable = (np.complex128, np.complex64)
  645. self._check_1d(fft, dtype, shape, axes, overwritable,
  646. fftsize, overwrite_x)
  647. self._check_1d(ifft, dtype, shape, axes, overwritable,
  648. fftsize, overwrite_x)
  649. @pytest.mark.parametrize('dtype', real_dtypes)
  650. @pytest.mark.parametrize('fftsize', fftsizes)
  651. @pytest.mark.parametrize('overwrite_x', [True, False])
  652. @pytest.mark.parametrize('shape,axes', [((16,), -1),
  653. ((16, 2), 0),
  654. ((2, 16), 1)])
  655. def test_rfft_irfft(self, dtype, fftsize, overwrite_x, shape, axes):
  656. overwritable = self.real_dtypes
  657. self._check_1d(irfft, dtype, shape, axes, overwritable,
  658. fftsize, overwrite_x)
  659. self._check_1d(rfft, dtype, shape, axes, overwritable,
  660. fftsize, overwrite_x)
  661. def _check_nd_one(self, routine, dtype, shape, axes, overwritable_dtypes,
  662. overwrite_x):
  663. np.random.seed(1234)
  664. if np.issubdtype(dtype, np.complexfloating):
  665. data = np.random.randn(*shape) + 1j*np.random.randn(*shape)
  666. else:
  667. data = np.random.randn(*shape)
  668. data = data.astype(dtype)
  669. def fftshape_iter(shp):
  670. if len(shp) <= 0:
  671. yield ()
  672. else:
  673. for j in (shp[0]//2, shp[0], shp[0]*2):
  674. for rest in fftshape_iter(shp[1:]):
  675. yield (j,) + rest
  676. if axes is None:
  677. part_shape = shape
  678. else:
  679. part_shape = tuple(np.take(shape, axes))
  680. for fftshape in fftshape_iter(part_shape):
  681. self._check(data, routine, fftshape, axes,
  682. overwrite_x=overwrite_x)
  683. if data.ndim > 1:
  684. self._check(data.T, routine, fftshape, axes,
  685. overwrite_x=overwrite_x)
  686. @pytest.mark.parametrize('dtype', dtypes)
  687. @pytest.mark.parametrize('overwrite_x', [True, False])
  688. @pytest.mark.parametrize('shape,axes', [((16,), None),
  689. ((16,), (0,)),
  690. ((16, 2), (0,)),
  691. ((2, 16), (1,)),
  692. ((8, 16), None),
  693. ((8, 16), (0, 1)),
  694. ((8, 16, 2), (0, 1)),
  695. ((8, 16, 2), (1, 2)),
  696. ((8, 16, 2), (0,)),
  697. ((8, 16, 2), (1,)),
  698. ((8, 16, 2), (2,)),
  699. ((8, 16, 2), None),
  700. ((8, 16, 2), (0, 1, 2))])
  701. def test_fftn_ifftn(self, dtype, overwrite_x, shape, axes):
  702. overwritable = (np.complex128, np.complex64)
  703. self._check_nd_one(fftn, dtype, shape, axes, overwritable,
  704. overwrite_x)
  705. self._check_nd_one(ifftn, dtype, shape, axes, overwritable,
  706. overwrite_x)
  707. @pytest.mark.parametrize('func', [fftn, ifftn, fft2])
  708. def test_shape_axes_ndarray(func):
  709. # Test fftn and ifftn work with NumPy arrays for shape and axes arguments
  710. # Regression test for gh-13342
  711. a = np.random.rand(10, 10)
  712. expect = func(a, shape=(5, 5))
  713. actual = func(a, shape=np.array([5, 5]))
  714. assert_equal(expect, actual)
  715. expect = func(a, axes=(-1,))
  716. actual = func(a, axes=np.array([-1,]))
  717. assert_equal(expect, actual)
  718. expect = func(a, shape=(4, 7), axes=(1, 0))
  719. actual = func(a, shape=np.array([4, 7]), axes=np.array([1, 0]))
  720. assert_equal(expect, actual)