test_helper.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555
  1. """Includes test functions for fftpack.helper module
  2. Copied from fftpack.helper by Pearu Peterson, October 2005
  3. Modified for Array API, 2023
  4. """
  5. from scipy.fft._helper import next_fast_len, prev_fast_len, _init_nd_shape_and_axes
  6. from numpy.testing import assert_equal
  7. from pytest import raises as assert_raises
  8. import pytest
  9. import numpy as np
  10. import sys
  11. from scipy._lib._array_api import xp_assert_close, xp_device
  12. from scipy import fft
  13. skip_xp_backends = pytest.mark.skip_xp_backends
  14. _5_smooth_numbers = [
  15. 2, 3, 4, 5, 6, 8, 9, 10,
  16. 2 * 3 * 5,
  17. 2**3 * 3**5,
  18. 2**3 * 3**3 * 5**2,
  19. ]
  20. def test_next_fast_len():
  21. for n in _5_smooth_numbers:
  22. assert_equal(next_fast_len(n), n)
  23. def _assert_n_smooth(x, n):
  24. x_orig = x
  25. if n < 2:
  26. assert False
  27. while True:
  28. q, r = divmod(x, 2)
  29. if r != 0:
  30. break
  31. x = q
  32. for d in range(3, n+1, 2):
  33. while True:
  34. q, r = divmod(x, d)
  35. if r != 0:
  36. break
  37. x = q
  38. assert x == 1, \
  39. f'x={x_orig} is not {n}-smooth, remainder={x}'
  40. class TestNextFastLen:
  41. def test_next_fast_len(self):
  42. np.random.seed(1234)
  43. def nums():
  44. yield from range(1, 1000)
  45. yield 2**5 * 3**5 * 4**5 + 1
  46. for n in nums():
  47. m = next_fast_len(n)
  48. _assert_n_smooth(m, 11)
  49. assert m == next_fast_len(n, False)
  50. m = next_fast_len(n, True)
  51. _assert_n_smooth(m, 5)
  52. def test_np_integers(self):
  53. ITYPES = [np.int16, np.int32, np.int64, np.uint16, np.uint32, np.uint64]
  54. for ityp in ITYPES:
  55. x = ityp(12345)
  56. testN = next_fast_len(x)
  57. assert_equal(testN, next_fast_len(int(x)))
  58. def testnext_fast_len_small(self):
  59. hams = {
  60. 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 8, 8: 8, 14: 15, 15: 15,
  61. 16: 16, 17: 18, 1021: 1024, 1536: 1536, 51200000: 51200000
  62. }
  63. for x, y in hams.items():
  64. assert_equal(next_fast_len(x, True), y)
  65. @pytest.mark.xfail(sys.maxsize < 2**32,
  66. reason="Hamming Numbers too large for 32-bit",
  67. raises=ValueError, strict=True)
  68. def testnext_fast_len_big(self):
  69. hams = {
  70. 510183360: 510183360, 510183360 + 1: 512000000,
  71. 511000000: 512000000,
  72. 854296875: 854296875, 854296875 + 1: 859963392,
  73. 196608000000: 196608000000, 196608000000 + 1: 196830000000,
  74. 8789062500000: 8789062500000, 8789062500000 + 1: 8796093022208,
  75. 206391214080000: 206391214080000,
  76. 206391214080000 + 1: 206624260800000,
  77. 470184984576000: 470184984576000,
  78. 470184984576000 + 1: 470715894135000,
  79. 7222041363087360: 7222041363087360,
  80. 7222041363087360 + 1: 7230196133913600,
  81. # power of 5 5**23
  82. 11920928955078125: 11920928955078125,
  83. 11920928955078125 - 1: 11920928955078125,
  84. # power of 3 3**34
  85. 16677181699666569: 16677181699666569,
  86. 16677181699666569 - 1: 16677181699666569,
  87. # power of 2 2**54
  88. 18014398509481984: 18014398509481984,
  89. 18014398509481984 - 1: 18014398509481984,
  90. # above this, int(ceil(n)) == int(ceil(n+1))
  91. 19200000000000000: 19200000000000000,
  92. 19200000000000000 + 1: 19221679687500000,
  93. 288230376151711744: 288230376151711744,
  94. 288230376151711744 + 1: 288325195312500000,
  95. 288325195312500000 - 1: 288325195312500000,
  96. 288325195312500000: 288325195312500000,
  97. 288325195312500000 + 1: 288555831593533440,
  98. }
  99. for x, y in hams.items():
  100. assert_equal(next_fast_len(x, True), y)
  101. def test_keyword_args(self, xp):
  102. assert next_fast_len(11, real=True) == 12
  103. assert next_fast_len(target=7, real=False) == 7
  104. class TestPrevFastLen:
  105. def test_prev_fast_len(self):
  106. np.random.seed(1234)
  107. def nums():
  108. yield from range(1, 1000)
  109. yield 2**5 * 3**5 * 4**5 + 1
  110. for n in nums():
  111. m = prev_fast_len(n)
  112. _assert_n_smooth(m, 11)
  113. assert m == prev_fast_len(n, False)
  114. m = prev_fast_len(n, True)
  115. _assert_n_smooth(m, 5)
  116. def test_np_integers(self):
  117. ITYPES = [np.int16, np.int32, np.int64, np.uint16, np.uint32,
  118. np.uint64]
  119. for ityp in ITYPES:
  120. x = ityp(12345)
  121. testN = prev_fast_len(x)
  122. assert_equal(testN, prev_fast_len(int(x)))
  123. testN = prev_fast_len(x, real=True)
  124. assert_equal(testN, prev_fast_len(int(x), real=True))
  125. def testprev_fast_len_small(self):
  126. hams = {
  127. 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 6, 8: 8, 14: 12, 15: 15,
  128. 16: 16, 17: 16, 1021: 1000, 1536: 1536, 51200000: 51200000
  129. }
  130. for x, y in hams.items():
  131. assert_equal(prev_fast_len(x, True), y)
  132. hams = {
  133. 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10,
  134. 11: 11, 12: 12, 13: 12, 14: 14, 15: 15, 16: 16, 17: 16, 18: 18,
  135. 19: 18, 20: 20, 21: 21, 22: 22, 120: 120, 121: 121, 122: 121,
  136. 1021: 1008, 1536: 1536, 51200000: 51200000
  137. }
  138. for x, y in hams.items():
  139. assert_equal(prev_fast_len(x, False), y)
  140. @pytest.mark.xfail(sys.maxsize < 2**32,
  141. reason="Hamming Numbers too large for 32-bit",
  142. raises=ValueError, strict=True)
  143. def testprev_fast_len_big(self):
  144. hams = {
  145. # 2**6 * 3**13 * 5**1
  146. 510183360: 510183360,
  147. 510183360 + 1: 510183360,
  148. 510183360 - 1: 509607936, # 2**21 * 3**5
  149. # 2**6 * 5**6 * 7**1 * 73**1
  150. 511000000: 510183360,
  151. 511000000 + 1: 510183360,
  152. 511000000 - 1: 510183360, # 2**6 * 3**13 * 5**1
  153. # 3**7 * 5**8
  154. 854296875: 854296875,
  155. 854296875 + 1: 854296875,
  156. 854296875 - 1: 850305600, # 2**6 * 3**12 * 5**2
  157. # 2**22 * 3**1 * 5**6
  158. 196608000000: 196608000000,
  159. 196608000000 + 1: 196608000000,
  160. 196608000000 - 1: 195910410240, # 2**13 * 3**14 * 5**1
  161. # 2**5 * 3**2 * 5**15
  162. 8789062500000: 8789062500000,
  163. 8789062500000 + 1: 8789062500000,
  164. 8789062500000 - 1: 8748000000000, # 2**11 * 3**7 * 5**9
  165. # 2**24 * 3**9 * 5**4
  166. 206391214080000: 206391214080000,
  167. 206391214080000 + 1: 206391214080000,
  168. 206391214080000 - 1: 206158430208000, # 2**39 * 3**1 * 5**3
  169. # 2**18 * 3**15 * 5**3
  170. 470184984576000: 470184984576000,
  171. 470184984576000 + 1: 470184984576000,
  172. 470184984576000 - 1: 469654673817600, # 2**33 * 3**7 **5**2
  173. # 2**25 * 3**16 * 5**1
  174. 7222041363087360: 7222041363087360,
  175. 7222041363087360 + 1: 7222041363087360,
  176. 7222041363087360 - 1: 7213895789838336, # 2**40 * 3**8
  177. # power of 5 5**23
  178. 11920928955078125: 11920928955078125,
  179. 11920928955078125 + 1: 11920928955078125,
  180. 11920928955078125 - 1: 11901557422080000, # 2**14 * 3**19 * 5**4
  181. # power of 3 3**34
  182. 16677181699666569: 16677181699666569,
  183. 16677181699666569 + 1: 16677181699666569,
  184. 16677181699666569 - 1: 16607531250000000, # 2**7 * 3**12 * 5**12
  185. # power of 2 2**54
  186. 18014398509481984: 18014398509481984,
  187. 18014398509481984 + 1: 18014398509481984,
  188. 18014398509481984 - 1: 18000000000000000, # 2**16 * 3**2 * 5**15
  189. # 2**20 * 3**1 * 5**14
  190. 19200000000000000: 19200000000000000,
  191. 19200000000000000 + 1: 19200000000000000,
  192. 19200000000000000 - 1: 19131876000000000, # 2**11 * 3**14 * 5**9
  193. # 2**58
  194. 288230376151711744: 288230376151711744,
  195. 288230376151711744 + 1: 288230376151711744,
  196. 288230376151711744 - 1: 288000000000000000, # 2**20 * 3**2 * 5**15
  197. # 2**5 * 3**10 * 5**16
  198. 288325195312500000: 288325195312500000,
  199. 288325195312500000 + 1: 288325195312500000,
  200. 288325195312500000 - 1: 288230376151711744, # 2**58
  201. }
  202. for x, y in hams.items():
  203. assert_equal(prev_fast_len(x, True), y)
  204. def test_keyword_args(self):
  205. assert prev_fast_len(11, real=True) == 10
  206. assert prev_fast_len(target=7, real=False) == 7
  207. @skip_xp_backends(cpu_only=True)
  208. class Test_init_nd_shape_and_axes:
  209. def test_py_0d_defaults(self, xp):
  210. x = xp.asarray(4)
  211. shape = None
  212. axes = None
  213. shape_expected = ()
  214. axes_expected = []
  215. shape_res, axes_res = _init_nd_shape_and_axes(x, shape, axes)
  216. assert shape_res == shape_expected
  217. assert axes_res == axes_expected
  218. def test_xp_0d_defaults(self, xp):
  219. x = xp.asarray(7.)
  220. shape = None
  221. axes = None
  222. shape_expected = ()
  223. axes_expected = []
  224. shape_res, axes_res = _init_nd_shape_and_axes(x, shape, axes)
  225. assert shape_res == shape_expected
  226. assert axes_res == axes_expected
  227. def test_py_1d_defaults(self, xp):
  228. x = xp.asarray([1, 2, 3])
  229. shape = None
  230. axes = None
  231. shape_expected = (3,)
  232. axes_expected = [0]
  233. shape_res, axes_res = _init_nd_shape_and_axes(x, shape, axes)
  234. assert shape_res == shape_expected
  235. assert axes_res == axes_expected
  236. def test_xp_1d_defaults(self, xp):
  237. x = xp.arange(0, 1, .1)
  238. shape = None
  239. axes = None
  240. shape_expected = (10,)
  241. axes_expected = [0]
  242. shape_res, axes_res = _init_nd_shape_and_axes(x, shape, axes)
  243. assert shape_res == shape_expected
  244. assert axes_res == axes_expected
  245. def test_py_2d_defaults(self, xp):
  246. x = xp.asarray([[1, 2, 3, 4],
  247. [5, 6, 7, 8]])
  248. shape = None
  249. axes = None
  250. shape_expected = (2, 4)
  251. axes_expected = [0, 1]
  252. shape_res, axes_res = _init_nd_shape_and_axes(x, shape, axes)
  253. assert shape_res == shape_expected
  254. assert axes_res == axes_expected
  255. def test_xp_2d_defaults(self, xp):
  256. x = xp.arange(0, 1, .1)
  257. x = xp.reshape(x, (5, 2))
  258. shape = None
  259. axes = None
  260. shape_expected = (5, 2)
  261. axes_expected = [0, 1]
  262. shape_res, axes_res = _init_nd_shape_and_axes(x, shape, axes)
  263. assert shape_res == shape_expected
  264. assert axes_res == axes_expected
  265. def test_xp_5d_defaults(self, xp):
  266. x = xp.zeros([6, 2, 5, 3, 4])
  267. shape = None
  268. axes = None
  269. shape_expected = (6, 2, 5, 3, 4)
  270. axes_expected = [0, 1, 2, 3, 4]
  271. shape_res, axes_res = _init_nd_shape_and_axes(x, shape, axes)
  272. assert shape_res == shape_expected
  273. assert axes_res == axes_expected
  274. def test_xp_5d_set_shape(self, xp):
  275. x = xp.zeros([6, 2, 5, 3, 4])
  276. shape = [10, -1, -1, 1, 4]
  277. axes = None
  278. shape_expected = (10, 2, 5, 1, 4)
  279. axes_expected = [0, 1, 2, 3, 4]
  280. shape_res, axes_res = _init_nd_shape_and_axes(x, shape, axes)
  281. assert shape_res == shape_expected
  282. assert axes_res == axes_expected
  283. def test_xp_5d_set_axes(self, xp):
  284. x = xp.zeros([6, 2, 5, 3, 4])
  285. shape = None
  286. axes = [4, 1, 2]
  287. shape_expected = (4, 2, 5)
  288. axes_expected = [4, 1, 2]
  289. shape_res, axes_res = _init_nd_shape_and_axes(x, shape, axes)
  290. assert shape_res == shape_expected
  291. assert axes_res == axes_expected
  292. def test_xp_5d_set_shape_axes(self, xp):
  293. x = xp.zeros([6, 2, 5, 3, 4])
  294. shape = [10, -1, 2]
  295. axes = [1, 0, 3]
  296. shape_expected = (10, 6, 2)
  297. axes_expected = [1, 0, 3]
  298. shape_res, axes_res = _init_nd_shape_and_axes(x, shape, axes)
  299. assert shape_res == shape_expected
  300. assert axes_res == axes_expected
  301. def test_shape_axes_subset(self, xp):
  302. x = xp.zeros((2, 3, 4, 5))
  303. shape, axes = _init_nd_shape_and_axes(x, shape=(5, 5, 5), axes=None)
  304. assert shape == (5, 5, 5)
  305. assert axes == [1, 2, 3]
  306. def test_errors(self, xp):
  307. x = xp.zeros(1)
  308. with assert_raises(ValueError, match="axes must be a scalar or "
  309. "iterable of integers"):
  310. _init_nd_shape_and_axes(x, shape=None, axes=[[1, 2], [3, 4]])
  311. with assert_raises(ValueError, match="axes must be a scalar or "
  312. "iterable of integers"):
  313. _init_nd_shape_and_axes(x, shape=None, axes=[1., 2., 3., 4.])
  314. with assert_raises(ValueError,
  315. match="axes exceeds dimensionality of input"):
  316. _init_nd_shape_and_axes(x, shape=None, axes=[1])
  317. with assert_raises(ValueError,
  318. match="axes exceeds dimensionality of input"):
  319. _init_nd_shape_and_axes(x, shape=None, axes=[-2])
  320. with assert_raises(ValueError,
  321. match="all axes must be unique"):
  322. _init_nd_shape_and_axes(x, shape=None, axes=[0, 0])
  323. with assert_raises(ValueError, match="shape must be a scalar or "
  324. "iterable of integers"):
  325. _init_nd_shape_and_axes(x, shape=[[1, 2], [3, 4]], axes=None)
  326. with assert_raises(ValueError, match="shape must be a scalar or "
  327. "iterable of integers"):
  328. _init_nd_shape_and_axes(x, shape=[1., 2., 3., 4.], axes=None)
  329. with assert_raises(ValueError,
  330. match="when given, axes and shape arguments"
  331. " have to be of the same length"):
  332. _init_nd_shape_and_axes(xp.zeros([1, 1, 1, 1]),
  333. shape=[1, 2, 3], axes=[1])
  334. with assert_raises(ValueError,
  335. match="invalid number of data points"
  336. r" \(\[0\]\) specified"):
  337. _init_nd_shape_and_axes(x, shape=[0], axes=None)
  338. with assert_raises(ValueError,
  339. match="invalid number of data points"
  340. r" \(\[-2\]\) specified"):
  341. _init_nd_shape_and_axes(x, shape=-2, axes=None)
  342. class TestFFTShift:
  343. def test_definition(self, xp):
  344. x = xp.asarray([0., 1, 2, 3, 4, -4, -3, -2, -1])
  345. y = xp.asarray([-4., -3, -2, -1, 0, 1, 2, 3, 4])
  346. xp_assert_close(fft.fftshift(x), y)
  347. xp_assert_close(fft.ifftshift(y), x)
  348. x = xp.asarray([0., 1, 2, 3, 4, -5, -4, -3, -2, -1])
  349. y = xp.asarray([-5., -4, -3, -2, -1, 0, 1, 2, 3, 4])
  350. xp_assert_close(fft.fftshift(x), y)
  351. xp_assert_close(fft.ifftshift(y), x)
  352. def test_inverse(self, xp):
  353. for n in [1, 4, 9, 100, 211]:
  354. x = xp.asarray(np.random.random((n,)))
  355. xp_assert_close(fft.ifftshift(fft.fftshift(x)), x)
  356. def test_axes_keyword(self, xp):
  357. freqs = xp.asarray([[0., 1, 2], [3, 4, -4], [-3, -2, -1]])
  358. shifted = xp.asarray([[-1., -3, -2], [2, 0, 1], [-4, 3, 4]])
  359. xp_assert_close(fft.fftshift(freqs, axes=(0, 1)), shifted)
  360. xp_assert_close(fft.fftshift(freqs, axes=0), fft.fftshift(freqs, axes=(0,)))
  361. xp_assert_close(fft.ifftshift(shifted, axes=(0, 1)), freqs)
  362. xp_assert_close(fft.ifftshift(shifted, axes=0),
  363. fft.ifftshift(shifted, axes=(0,)))
  364. xp_assert_close(fft.fftshift(freqs), shifted)
  365. xp_assert_close(fft.ifftshift(shifted), freqs)
  366. def test_uneven_dims(self, xp):
  367. """ Test 2D input, which has uneven dimension sizes """
  368. freqs = xp.asarray([
  369. [0, 1],
  370. [2, 3],
  371. [4, 5]
  372. ], dtype=xp.float64)
  373. # shift in dimension 0
  374. shift_dim0 = xp.asarray([
  375. [4, 5],
  376. [0, 1],
  377. [2, 3]
  378. ], dtype=xp.float64)
  379. xp_assert_close(fft.fftshift(freqs, axes=0), shift_dim0)
  380. xp_assert_close(fft.ifftshift(shift_dim0, axes=0), freqs)
  381. xp_assert_close(fft.fftshift(freqs, axes=(0,)), shift_dim0)
  382. xp_assert_close(fft.ifftshift(shift_dim0, axes=[0]), freqs)
  383. # shift in dimension 1
  384. shift_dim1 = xp.asarray([
  385. [1, 0],
  386. [3, 2],
  387. [5, 4]
  388. ], dtype=xp.float64)
  389. xp_assert_close(fft.fftshift(freqs, axes=1), shift_dim1)
  390. xp_assert_close(fft.ifftshift(shift_dim1, axes=1), freqs)
  391. # shift in both dimensions
  392. shift_dim_both = xp.asarray([
  393. [5, 4],
  394. [1, 0],
  395. [3, 2]
  396. ], dtype=xp.float64)
  397. xp_assert_close(fft.fftshift(freqs, axes=(0, 1)), shift_dim_both)
  398. xp_assert_close(fft.ifftshift(shift_dim_both, axes=(0, 1)), freqs)
  399. xp_assert_close(fft.fftshift(freqs, axes=[0, 1]), shift_dim_both)
  400. xp_assert_close(fft.ifftshift(shift_dim_both, axes=[0, 1]), freqs)
  401. # axes=None (default) shift in all dimensions
  402. xp_assert_close(fft.fftshift(freqs, axes=None), shift_dim_both)
  403. xp_assert_close(fft.ifftshift(shift_dim_both, axes=None), freqs)
  404. xp_assert_close(fft.fftshift(freqs), shift_dim_both)
  405. xp_assert_close(fft.ifftshift(shift_dim_both), freqs)
  406. class TestFFTFreq:
  407. def test_definition(self, xp):
  408. x = xp.asarray([0, 1, 2, 3, 4, -4, -3, -2, -1], dtype=xp.float64)
  409. x2 = xp.asarray([0, 1, 2, 3, 4, -5, -4, -3, -2, -1], dtype=xp.float64)
  410. # default dtype varies across backends
  411. y = 9 * fft.fftfreq(9, xp=xp)
  412. xp_assert_close(y, x, check_dtype=False, check_namespace=True)
  413. y = 9 * xp.pi * fft.fftfreq(9, xp.pi, xp=xp)
  414. xp_assert_close(y, x, check_dtype=False)
  415. y = 10 * fft.fftfreq(10, xp=xp)
  416. xp_assert_close(y, x2, check_dtype=False)
  417. y = 10 * xp.pi * fft.fftfreq(10, xp.pi, xp=xp)
  418. xp_assert_close(y, x2, check_dtype=False)
  419. def test_device(self, xp, devices):
  420. for d in devices:
  421. y = fft.fftfreq(9, xp=xp, device=d)
  422. x = xp.empty(0, device=d)
  423. assert xp_device(y) == xp_device(x)
  424. class TestRFFTFreq:
  425. def test_definition(self, xp):
  426. x = xp.asarray([0, 1, 2, 3, 4], dtype=xp.float64)
  427. x2 = xp.asarray([0, 1, 2, 3, 4, 5], dtype=xp.float64)
  428. # default dtype varies across backends
  429. y = 9 * fft.rfftfreq(9, xp=xp)
  430. xp_assert_close(y, x, check_dtype=False, check_namespace=True)
  431. y = 9 * xp.pi * fft.rfftfreq(9, xp.pi, xp=xp)
  432. xp_assert_close(y, x, check_dtype=False)
  433. y = 10 * fft.rfftfreq(10, xp=xp)
  434. xp_assert_close(y, x2, check_dtype=False)
  435. y = 10 * xp.pi * fft.rfftfreq(10, xp.pi, xp=xp)
  436. xp_assert_close(y, x2, check_dtype=False)
  437. def test_device(self, xp, devices):
  438. for d in devices:
  439. y = fft.rfftfreq(9, xp=xp, device=d)
  440. x = xp.empty(0, device=d)
  441. assert xp_device(y) == xp_device(x)