| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- """Tests for spline filtering."""
- import pytest
- import numpy as np
- from scipy._lib._array_api import assert_almost_equal, make_xp_test_case
- from scipy import ndimage
def get_spline_knot_values(order):
    """Return one side of a B-spline's integer knot values, center first.

    The first entry is the value at the spline's center; the remaining
    entries are the values at successive knots to its right.  Orders 0
    through 5 are tabulated; any other order raises ``KeyError``.
    """
    per_order = {
        0: (1,),
        1: (1,),
        2: (6, 1),
        3: (4, 1),
        4: (230, 76, 1),
        5: (66, 26, 1),
    }
    # Hand back a fresh list so callers may slice or mutate it freely.
    return list(per_order[order])
def make_spline_knot_matrix(xp, n, order, mode='mirror'):
    """Build the normalized ``n x n`` matrix relating spline coefficients
    to sample values for the given boundary ``mode``.

    The interior is a symmetric band whose diagonals hold the knot values
    of a B-spline of the requested ``order``.  The first and last rows are
    then corrected by folding the knot values that would fall outside the
    matrix back onto columns chosen by ``mode`` (``'mirror'``,
    ``'reflect'`` or ``'grid-wrap'``).  Every entry is divided by the sum
    of the full, two-sided set of knot values.

    The matrix is assembled with NumPy and converted with ``xp.asarray``
    only when it is returned.  Any other ``mode`` raises ``ValueError``.
    """
    weights = get_spline_knot_values(order)

    # NB: all arithmetic is done in NumPy; xp is used for the final cast only
    banded = np.zeros((n, n))
    for offset, weight in enumerate(weights):
        rows = np.arange(offset, n)
        cols = rows - offset
        # Sub- and super-diagonal share a value (they coincide at offset 0).
        banded[rows, cols] = weight
        banded[cols, rows] = weight

    # Center value once, every off-center value on both sides.
    total = weights[0] + 2 * sum(weights[1:])

    # (first column, column stride) receiving the folded-back knot values.
    fold_rules = {
        'mirror': (1, 1),
        'reflect': (0, 1),
        'grid-wrap': (-1, -1),
    }
    if mode not in fold_rules:
        raise ValueError(f'unsupported mode {mode}')
    first, stride = fold_rules[mode]

    for row in range(len(weights) - 1):
        overhang = weights[row + 1:]
        for k, weight in enumerate(overhang):
            col = first + stride * k
            banded[row, col] += weight
            # Same correction on the opposite corner, indexed from the end.
            banded[-row - 1, -col - 1] += weight

    return xp.asarray(banded / total)
@make_xp_test_case(ndimage.spline_filter1d)
@pytest.mark.parametrize('order', [0, 1, 2, 3, 4, 5])
@pytest.mark.parametrize('mode', ['mirror', 'grid-wrap', 'reflect'])
def test_spline_filter_vs_matrix_solution(order, mode, xp):
    """Filtering the identity must invert the spline knot matrix.

    ``spline_filter1d`` is applied to an identity matrix along each axis;
    multiplying the result by the knot matrix (transposed for axis 1)
    should give back the identity.
    """
    size = 100
    identity = xp.eye(size, dtype=xp.float64)
    filter_kwargs = dict(order=order, mode=mode)

    filtered_rows = ndimage.spline_filter1d(identity, axis=0, **filter_kwargs)
    filtered_cols = ndimage.spline_filter1d(identity, axis=1, **filter_kwargs)
    knot_matrix = make_spline_knot_matrix(xp, size, order, mode=mode)

    assert_almost_equal(identity, filtered_rows @ knot_matrix)
    assert_almost_equal(identity, filtered_cols @ knot_matrix.T)
|