test_splines.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. """Tests for spline filtering."""
  2. import pytest
  3. import numpy as np
  4. from scipy._lib._array_api import assert_almost_equal, make_xp_test_case
  5. from scipy import ndimage
  6. def get_spline_knot_values(order):
  7. """Knot values to the right of a B-spline's center."""
  8. knot_values = {0: [1],
  9. 1: [1],
  10. 2: [6, 1],
  11. 3: [4, 1],
  12. 4: [230, 76, 1],
  13. 5: [66, 26, 1]}
  14. return knot_values[order]
  15. def make_spline_knot_matrix(xp, n, order, mode='mirror'):
  16. """Matrix to invert to find the spline coefficients."""
  17. knot_values = get_spline_knot_values(order)
  18. # NB: do computations with numpy, convert to xp as the last step only
  19. matrix = np.zeros((n, n))
  20. for diag, knot_value in enumerate(knot_values):
  21. indices = np.arange(diag, n)
  22. if diag == 0:
  23. matrix[indices, indices] = knot_value
  24. else:
  25. matrix[indices, indices - diag] = knot_value
  26. matrix[indices - diag, indices] = knot_value
  27. knot_values_sum = knot_values[0] + 2 * sum(knot_values[1:])
  28. if mode == 'mirror':
  29. start, step = 1, 1
  30. elif mode == 'reflect':
  31. start, step = 0, 1
  32. elif mode == 'grid-wrap':
  33. start, step = -1, -1
  34. else:
  35. raise ValueError(f'unsupported mode {mode}')
  36. for row in range(len(knot_values) - 1):
  37. for idx, knot_value in enumerate(knot_values[row + 1:]):
  38. matrix[row, start + step*idx] += knot_value
  39. matrix[-row - 1, -start - 1 - step*idx] += knot_value
  40. return xp.asarray(matrix / knot_values_sum)
  41. @make_xp_test_case(ndimage.spline_filter1d)
  42. @pytest.mark.parametrize('order', [0, 1, 2, 3, 4, 5])
  43. @pytest.mark.parametrize('mode', ['mirror', 'grid-wrap', 'reflect'])
  44. def test_spline_filter_vs_matrix_solution(order, mode, xp):
  45. n = 100
  46. eye = xp.eye(n, dtype=xp.float64)
  47. spline_filter_axis_0 = ndimage.spline_filter1d(eye, axis=0, order=order,
  48. mode=mode)
  49. spline_filter_axis_1 = ndimage.spline_filter1d(eye, axis=1, order=order,
  50. mode=mode)
  51. matrix = make_spline_knot_matrix(xp, n, order, mode=mode)
  52. assert_almost_equal(eye, spline_filter_axis_0 @ matrix)
  53. assert_almost_equal(eye, spline_filter_axis_1 @ matrix.T)