test_bsplines.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. # pylint: disable=missing-docstring
  2. import math
  3. import numpy as np
  4. from scipy._lib._array_api import (
  5. assert_almost_equal, xp_assert_close, xp_assert_equal, make_xp_test_case,
  6. xp_default_dtype, array_namespace, _xp_copy_to_numpy
  7. )
  8. import pytest
  9. from pytest import raises
  10. from scipy import signal
  # Short aliases for the pytest markers used below to skip / xfail a test
  # on selected array backends.
  xfail_xp_backends = pytest.mark.xfail_xp_backends
  skip_xp_backends = pytest.mark.skip_xp_backends
  # NOTE(review): presumably read by SciPy's array-API test tooling to know
  # which modules to instrument for lazy backends -- confirm.
  lazy_xp_modules = [signal]
  class TestBSplines:
      """Checks for the B-spline helpers of `scipy.signal`.

      Part of the reference values were returned by SciPy 1.1.0 and are kept
      as regression guards; the remaining ones (at integer points) follow the
      theoretical expressions in Unser, Aldroubi, Eden, IEEE TSP 1993, Table 1.
      """

      @make_xp_test_case(signal.spline_filter)
      def test_spline_filter(self, xp):
          """Regression values for `spline_filter` on real 12x12 data."""
          # Type-error branch: an integer array is expected to be rejected.
          int_input = xp.asarray([0])
          with raises(TypeError):
              signal.spline_filter(int_input, 0)

          # Real branch: rescale uniform samples with 10*(1 - 2*u) so that
          # the magnitude exceeds 1 and some entries are negative.
          noise = np.random.RandomState(12457).rand(12, 12)
          image = xp.asarray(10*(1 - 2*noise))
          expected = xp.asarray(
              [[-.463312621, 8.33391222, .697290949, 5.28390836,
                5.92066474, 6.59452137, 9.84406950, -8.78324188,
                7.20675750, -8.17222994, -4.38633345, 9.89917069],
               [2.67755154, 6.24192170, -3.15730578, 9.87658581,
                -9.96930425, 3.17194115, -4.50919947, 5.75423446,
                9.65979824, -8.29066885, .971416087, -2.38331897],
               [-7.08868346, 4.89887705, -1.37062289, 7.70705838,
                2.51526461, 3.65885497, 5.16786604, -8.77715342e-03,
                4.10533325, 9.04761993, -.577960351, 9.86382519],
               [-4.71444301, -1.68038985, 2.84695116, 1.14315938,
                -3.17127091, 1.91830461, 7.13779687, -5.35737482,
                -9.66586425, -9.87717456, 9.93160672, 4.71948144],
               [9.49551194, -1.92958436, 6.25427993, -9.05582911,
                3.97562282, 7.68232426, -1.04514824, -5.86021443,
                -8.43007451, 5.47528997, 2.06330736, -8.65968112],
               [-8.91720100, 8.87065356, 3.76879937, 2.56222894,
                -.828387146, 8.72288903, 6.42474741, -6.84576083,
                9.94724115, 6.90665380, -6.61084494, -9.44907391],
               [9.25196790, -.774032030, 7.05371046, -2.73505725,
                2.53953305, -1.82889155, 2.95454824, -1.66362046,
                5.72478916, -3.10287679, 1.54017123, -7.87759020],
               [-3.98464539, -2.44316992, -1.12708657, 1.01725672,
                -8.89294671, -5.42145629, -6.16370321, 2.91775492,
                9.64132208, .702499998, -2.02622392, 1.56308431],
               [-2.22050773, 7.89951554, 5.98970713, -7.35861835,
                5.45459283, -7.76427957, 3.67280490, -4.05521315,
                4.51967507, -3.22738749, -3.65080177, 3.05630155],
               [-6.21240584, -.296796126, -8.34800163, 9.21564563,
                -3.61958784, -4.77120006, -3.99454057, 1.05021988e-03,
                -6.95982829, 6.04380797, 8.43181250, -2.71653339],
               [1.19638037, 6.99718842e-02, 6.72020394, -2.13963198,
                3.75309875, -5.70076744, 5.92143551, -7.22150575,
                -3.77114594, -1.11903194, -5.39151466, 3.06620093],
               [9.86326886, 1.05134482, -7.75950607, -3.64429655,
                7.81848957, -9.02270373, 3.73399754, -4.71962549,
                -7.71144306, 3.78263161, 6.46034818, -4.43444731]],
              dtype=xp.float64)

          filtered = signal.spline_filter(image, 0)
          xp_assert_close(filtered, expected)

      @make_xp_test_case(signal.spline_filter)
      def test_spline_filter_complex(self, xp):
          """Regression values for `spline_filter` on complex 7x7 data."""
          rng = np.random.RandomState(12457)
          # Real part is drawn first, imaginary part second.
          noise = rng.rand(7, 7) + rng.rand(7, 7)*1j
          # Rescale so the magnitude exceeds 1 and some parts are negative.
          image = xp.asarray(10*(1 + 1j - 2*noise))
          expected = xp.asarray(
              [[-4.61489230e-01-1.92994022j, 8.33332443+6.25519943j,
                6.96300745e-01-9.05576038j, 5.28294849+3.97541356j,
                5.92165565+7.68240595j, 6.59493160-1.04542804j,
                9.84503460-5.85946894j],
               [-8.78262329-8.4295969j, 7.20675516+5.47528982j,
                -8.17223072+2.06330729j, -4.38633347-8.65968037j,
                9.89916801-8.91720295j, 2.67755103+8.8706522j,
                6.24192142+3.76879835j],
               [-3.15627527+2.56303072j, 9.87658501-0.82838702j,
                -9.96930313+8.72288895j, 3.17193985+6.42474651j,
                -4.50919819-6.84576082j, 5.75423431+9.94723988j,
                9.65979767+6.90665293j],
               [-8.28993416-6.61064005j, 9.71416473e-01-9.44907284j,
                -2.38331890+9.25196648j, -7.08868170-0.77403212j,
                4.89887714+7.05371094j, -1.37062311-2.73505688j,
                7.70705748+2.5395329j],
               [2.51528406-1.82964492j, 3.65885472+2.95454836j,
                5.16786575-1.66362023j, -8.77737999e-03+5.72478867j,
                4.10533333-3.10287571j, 9.04761887+1.54017115j,
                -5.77960968e-01-7.87758923j],
               [9.86398506-3.98528528j, -4.71444130-2.44316983j,
                -1.68038976-1.12708664j, 2.84695053+1.01725709j,
                1.14315915-8.89294529j, -3.17127085-5.42145538j,
                1.91830420-6.16370344j],
               [7.13875294+2.91851187j, -5.35737514+9.64132309j,
                -9.66586399+0.70250005j, -9.87717438-2.0262239j,
                9.93160629+1.5630846j, 4.71948051-2.22050714j,
                9.49550819+7.8995142j]],
              dtype=xp.complex128)

          # FIXME: for complex types, the computations are done in
          # single precision (reason unclear). When this is changed,
          # this test needs updating (hence the loosened rtol).
          filtered = signal.spline_filter(image, 0)
          xp_assert_close(filtered, expected, rtol=1e-6)

      @make_xp_test_case(signal.gauss_spline)
      def test_gauss_spline(self, xp):
          """`gauss_spline` on a Python scalar and on a one-element array."""
          at_origin = signal.gauss_spline(0, 0)
          assert math.isclose(at_origin, 1.381976597885342)

          at_one = signal.gauss_spline(xp.asarray([1.]), 1)
          xp_assert_close(at_one, xp.asarray([0.04865217]), atol=1e-9)

      @skip_xp_backends(np_only=True, reason="deliberate: array-likes are accepted")
      @make_xp_test_case(signal.gauss_spline)
      def test_gauss_spline_list(self, xp):
          """A plain Python list is accepted (regression test for gh-12152)."""
          knots = [-1.0, 0.0, -1.0]
          expected = np.asarray([0.15418033, 0.6909883, 0.15418033])
          assert_almost_equal(signal.gauss_spline(knots, 3), expected)

      @make_xp_test_case(signal.cspline1d)
      def test_cspline1d(self, xp):
          """Cubic spline coefficients with and without a second argument."""
          zero_coeffs = signal.cspline1d(xp.asarray([0]))
          xp_assert_equal(zero_coeffs, xp.asarray([0.], dtype=xp.float64))

          # Non-zero second argument (lambda).
          expected_smoothed = xp.asarray(
              [1.21037185, 1.86293902, 2.98834059, 4.11660378, 4.78893826],
              dtype=xp.float64)
          smoothed = signal.cspline1d(xp.asarray([1., 2, 3, 4, 5]), 1)
          xp_assert_close(smoothed, expected_smoothed)

          # Second argument left at its default.
          expected_default = xp.asarray(
              [0.78683946, 2.05333735, 2.99981113, 3.94741812, 5.21051638],
              dtype=xp.float64)
          default = signal.cspline1d(xp.asarray([1., 2, 3, 4, 5]))
          xp_assert_close(default, expected_default)

      @make_xp_test_case(signal.qspline1d)
      def test_qspline1d(self, xp):
          """Quadratic spline coefficients; a non-zero lambda must raise."""
          zero_coeffs = signal.qspline1d(xp.asarray([0]))
          xp_assert_equal(zero_coeffs, xp.asarray([0.], dtype=xp.float64))

          # Non-zero second argument (lambda): ValueError for either sign.
          for lam in (1., -1.):
              samples = xp.asarray([1., 2, 3, 4, 5])
              with raises(ValueError):
                  signal.qspline1d(samples, lam)

          expected = xp.asarray(
              [0.85350007, 2.02441743, 2.99999534, 3.97561055, 5.14634135],
              dtype=xp.float64)
          coeffs = signal.qspline1d(
              xp.asarray([1., 2, 3, 4, 5], dtype=xp.float64))
          xp_assert_close(coeffs, expected)

      @xfail_xp_backends("cupy", reason="https://github.com/cupy/cupy/pull/9484")
      @make_xp_test_case(signal.cspline1d_eval)
      def test_cspline1d_eval(self, xp):
          """Evaluate cubic splines: small cases, a long grid, empty `cj`."""
          f64 = xp.float64

          out = signal.cspline1d_eval(xp.asarray([0., 0], dtype=f64),
                                      xp.asarray([0.], dtype=f64))
          xp_assert_close(out, xp.asarray([0.], dtype=f64))

          # Empty `newx` gives an empty result.
          out = signal.cspline1d_eval(xp.asarray([1., 0, 1], dtype=f64),
                                      xp.asarray([], dtype=f64))
          xp_assert_equal(out, xp.asarray([], dtype=f64))

          # Test case for newx that gets filtered down to empty
          out = signal.cspline1d_eval(xp.asarray([1.0, 0, 1], dtype=f64),
                                      xp.asarray([-1.0], dtype=f64))
          xp_assert_close(out, xp.asarray([0.33333333], dtype=f64))

          grid = list(range(-3, 7))        # -3, -2, ..., 6
          step = grid[1] - grid[0]
          # -6.0, -5.5, ..., 12.5 (halves are exact in binary floating point)
          query = [k / 2 for k in range(-12, 26)]
          y = xp.asarray([4.216, 6.864, 3.514, 6.203, 6.759, 7.433, 7.874, 5.879,
                          1.396, 4.094])
          # Coefficients are computed on a NumPy copy and converted back.
          cj = xp.asarray(signal.cspline1d(_xp_copy_to_numpy(y)))
          newy = xp.asarray(
              [6.203, 4.41570658, 3.514, 5.16924703, 6.864, 6.04643068,
               4.21600281, 6.04643068, 6.864, 5.16924703, 3.514,
               4.41570658, 6.203, 6.80717667, 6.759, 6.98971173, 7.433,
               7.79560142, 7.874, 7.41525761, 5.879, 3.18686814, 1.396,
               2.24889482, 4.094, 2.24889482, 1.396, 3.18686814, 5.879,
               7.41525761, 7.874, 7.79560142, 7.433, 6.98971173, 6.759,
               6.80717667, 6.203, 4.41570658],
              dtype=f64)
          evaluated = signal.cspline1d_eval(
              cj, xp.asarray(query), dx=step, x0=grid[0])
          xp_assert_close(evaluated, newy)

          with pytest.raises(ValueError,
                             match="Spline coefficients 'cj' must not be empty."):
              signal.cspline1d_eval(xp.asarray([], dtype=f64),
                                    xp.asarray([0.0], dtype=f64))

      @xfail_xp_backends("cupy", reason="https://github.com/cupy/cupy/pull/9484")
      @make_xp_test_case(signal.qspline1d_eval)
      def test_qspline1d_eval(self, xp):
          """Evaluate quadratic splines: small cases, a long grid, empty `cj`."""
          f64 = xp.float64

          out = signal.qspline1d_eval(xp.asarray([0., 0]), xp.asarray([0.]))
          xp_assert_close(out, xp.asarray([0.]))

          # Empty `newx` gives an empty result.
          out = signal.qspline1d_eval(xp.asarray([1., 0, 1]), xp.asarray([]))
          xp_assert_equal(out, xp.asarray([]))

          # Test case for newx that gets filtered down to empty
          out = signal.qspline1d_eval(xp.asarray([1.0, 0, 1], dtype=f64),
                                      xp.asarray([-1.0], dtype=f64))
          xp_assert_equal(out, xp.asarray([0.25], dtype=f64))

          grid = list(range(-3, 7))        # -3, -2, ..., 6
          step = grid[1] - grid[0]
          # -6.0, -5.5, ..., 12.5 (halves are exact in binary floating point)
          query = [k / 2 for k in range(-12, 26)]
          y = xp.asarray([4.216, 6.864, 3.514, 6.203, 6.759, 7.433, 7.874, 5.879,
                          1.396, 4.094])
          cj = signal.qspline1d(y)
          newy = xp.asarray(
              [6.203, 4.49418159, 3.514, 5.18390821, 6.864, 5.91436915,
               4.21600002, 5.91436915, 6.864, 5.18390821, 3.514,
               4.49418159, 6.203, 6.71900226, 6.759, 7.03980488, 7.433,
               7.81016848, 7.874, 7.32718426, 5.879, 3.23872593, 1.396,
               2.34046013, 4.094, 2.34046013, 1.396, 3.23872593, 5.879,
               7.32718426, 7.874, 7.81016848, 7.433, 7.03980488, 6.759,
               6.71900226, 6.203, 4.49418159],
              dtype=f64)
          evaluated = signal.qspline1d_eval(
              cj, xp.asarray(query, dtype=f64), dx=step, x0=grid[0])
          xp_assert_close(evaluated, newy)

          with pytest.raises(ValueError,
                             match="Spline coefficients 'cj' must not be empty."):
              signal.qspline1d_eval(xp.asarray([], dtype=f64),
                                    xp.asarray([0.0], dtype=f64))
  # Input dtype -> dtype expected from `signal.sepfir2d`, as recorded with
  # scipy 1.9.1 (likely pinned by backwards compatibility).
  sepfir_dtype_map = {
      np.uint8: np.float32,
      int: np.float64,
      np.float32: np.float32,
      float: float,
      np.complex64: np.complex64,
      complex: complex,
  }
  @skip_xp_backends(np_only=True)
  class TestSepfir2d:
      """Tests for `signal.sepfir2d`; the whole class is marked NumPy-only."""

      def test_sepfir2d_invalid_filter(self, xp):
          """Even-length or non-1-D row/column filters raise ValueError."""
          filt = xp.asarray([1.0, 2.0, 4.0, 2.0, 1.0])
          # Only the image shape matters here, so the values are unseeded.
          image = np.random.rand(7, 9)
          image = xp.asarray(image)

          # No error for odd lengths
          signal.sepfir2d(image, filt, filt[2:])

          # Row or column filter must be odd
          with pytest.raises(ValueError, match="odd length"):
              signal.sepfir2d(image, filt, filt[1:])
          with pytest.raises(ValueError, match="odd length"):
              signal.sepfir2d(image, filt[1:], filt)

          # Filters must be 1-dimensional
          with pytest.raises(ValueError, match="object too deep"):
              signal.sepfir2d(image, xp.reshape(filt, (1, -1)), filt)
          with pytest.raises(ValueError, match="object too deep"):
              signal.sepfir2d(image, filt, xp.reshape(filt, (1, -1)))

      def test_sepfir2d_invalid_image(self, xp):
          """Images that are not 2-D raise ValueError."""
          filt = xp.asarray([1.0, 2.0, 4.0, 2.0, 1.0])
          # Only the image shape matters here, so the values are unseeded.
          image = np.random.rand(8, 8)
          image = xp.asarray(image)

          # Image must be 2 dimensional
          with pytest.raises(ValueError, match="object too deep"):
              signal.sepfir2d(xp.reshape(image, (4, 4, 4)), filt, filt)

          with pytest.raises(ValueError, match="object of too small depth"):
              signal.sepfir2d(image[0, :], filt, filt)

      @pytest.mark.parametrize('dtyp',
          [np.uint8, int, np.float32, float, np.complex64, complex]
      )
      def test_simple(self, dtyp, xp):
          """Paper-and-pencil example, checked for every input dtype."""
          a = np.array([[1, 2, 3, 3, 2, 1],
                        [1, 2, 3, 3, 2, 1],
                        [1, 2, 3, 3, 2, 1],
                        [1, 2, 3, 3, 2, 1]], dtype=dtyp)
          h1 = [0.5, 1, 0.5]
          h2 = [1]
          # Expected output dtype for this input dtype.
          dt = sepfir_dtype_map[dtyp]

          # h1 passed as the first filter, h2 as the second.
          result = signal.sepfir2d(a, h1, h2)
          expected = np.asarray([[2.5, 4. , 5.5, 5.5, 4. , 2.5],
                                 [2.5, 4. , 5.5, 5.5, 4. , 2.5],
                                 [2.5, 4. , 5.5, 5.5, 4. , 2.5],
                                 [2.5, 4. , 5.5, 5.5, 4. , 2.5]], dtype=dt)
          xp_assert_close(result, expected, atol=1e-16)

          # Same filters with their roles swapped.
          result = signal.sepfir2d(a, h2, h1)
          expected = np.asarray([[2., 4., 6., 6., 4., 2.],
                                 [2., 4., 6., 6., 4., 2.],
                                 [2., 4., 6., 6., 4., 2.],
                                 [2., 4., 6., 6., 4., 2.]], dtype=dt)
          xp_assert_close(result, expected, atol=1e-16)

      @skip_xp_backends(np_only=True, reason="TODO: convert this test")
      @pytest.mark.parametrize('dtyp',
          [np.uint8, int, np.float32, float, np.complex64, complex]
      )
      def test_strided(self, dtyp, xp):
          """A strided image view gives the same result as its contiguous copy."""
          # `dtype=dtyp` makes the parametrization effective: without it every
          # parametrized run exercised the same default-integer input.
          a = np.array([[1, 2, 3, 3, 2, 1, 1, 2, 3],
                        [1, 2, 3, 3, 2, 1, 1, 2, 3],
                        [1, 2, 3, 3, 2, 1, 1, 2, 3],
                        [1, 2, 3, 3, 2, 1, 1, 2, 3]], dtype=dtyp)
          h1, h2 = [0.5, 1, 0.5], [1]

          strided_view = a[:, ::2]
          result_strided = signal.sepfir2d(strided_view, h1, h2)
          result_contig = signal.sepfir2d(strided_view.copy(), h1, h2)

          xp_assert_close(result_strided, result_contig, atol=1e-15)
          assert result_strided.dtype == result_contig.dtype

      @skip_xp_backends(np_only=True, reason="TODO: convert this test")
      @pytest.mark.xfail(reason="XXX: filt.size > image.shape: flaky")
      def test_sepfir2d_strided_2(self, xp):
          """Strided filter that is longer than the image side (marked xfail)."""
          # XXX: this test is flaky: fails on some reruns, with
          # result[0, 1] and result[1, 1] being ~1e+224.
          # NOTE(review): `image` comes from the unseeded global RNG while
          # `expected` is hard-coded, so the comparison cannot be reproducible
          # as written; presumably a fixed seed was intended -- confirm.
          filt = np.array([1.0, 2.0, 4.0, 2.0, 1.0, 3.0, 2.0])
          image = np.random.rand(4, 4)

          expected = np.asarray([[36.018162, 30.239061, 38.71187 , 43.878183],
                                 [38.180999, 35.824583, 43.525247, 43.874945],
                                 [43.269533, 40.834018, 46.757772, 44.276423],
                                 [49.120928, 39.681844, 43.596067, 45.085854]])
          xp_assert_close(signal.sepfir2d(image, filt, filt[::3]), expected)

      @skip_xp_backends(np_only=True, reason="TODO: convert this test")
      @pytest.mark.xfail(reason="XXX: flaky. pointers OOB on some platforms")
      @pytest.mark.fail_asan
      @pytest.mark.parametrize('dtyp',
          [np.uint8, int, np.float32, float, np.complex64, complex]
      )
      def test_sepfir2d_strided_3(self, dtyp, xp):
          """Strided filters on a 5x5 image, per dtype (marked xfail)."""
          # NB: 'image' and 'filt' dtypes match here. Otherwise we can run into
          # unsafe casting errors for many combinations. Historically, dtype
          # handling in `sepfir2d` is a tad baroque; fixing it is an enhancement.
          filt = np.array([1, 2, 4, 2, 1, 3, 2], dtype=dtyp)
          image = np.asarray([[0, 3, 0, 1, 2],
                              [2, 2, 3, 3, 3],
                              [0, 1, 3, 0, 3],
                              [2, 3, 0, 1, 3],
                              [3, 3, 2, 1, 2]], dtype=dtyp)

          # Contiguous first filter, strided second filter.
          expected = [[123., 101., 91., 136., 127.],
                      [133., 125., 126., 152., 160.],
                      [136., 137., 150., 162., 177.],
                      [133., 124., 132., 148., 147.],
                      [173., 158., 152., 164., 141.]]
          expected = np.asarray(expected)
          result = signal.sepfir2d(image, filt, filt[::3])
          xp_assert_close(result, expected, atol=1e-15)
          assert result.dtype == sepfir_dtype_map[dtyp]

          # Both filters strided.
          expected = [[22., 35., 41., 31., 47.],
                      [27., 39., 48., 47., 55.],
                      [33., 42., 49., 53., 59.],
                      [39., 44., 41., 36., 48.],
                      [67., 62., 47., 34., 46.]]
          expected = np.asarray(expected)
          result = signal.sepfir2d(image, filt[::3], filt[::3])
          xp_assert_close(result, expected, atol=1e-15)
          assert result.dtype == sepfir_dtype_map[dtyp]
  @make_xp_test_case(signal.cspline2d)
  def test_cspline2d(xp):
      """`cspline2d` returns an array from the same namespace as its input."""
      pixels = np.random.RandomState(181819142).rand(71, 73)
      image = xp.asarray(pixels, dtype=xp_default_dtype(xp))

      coeffs = signal.cspline2d(image, 8.0)
      assert array_namespace(coeffs) == xp
  @make_xp_test_case(signal.qspline2d)
  def test_qspline2d(xp):
      """`qspline2d` returns an array from the same namespace as its input."""
      pixels = np.random.RandomState(181819143).rand(71, 73)
      image = xp.asarray(pixels, dtype=xp_default_dtype(xp))

      coeffs = signal.qspline2d(image)
      assert array_namespace(coeffs) == xp
  347. assert array_namespace(result) == xp