test_wavelets.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. import numpy as np
  2. from numpy.testing import assert_array_equal, assert_array_almost_equal
  3. import scipy.signal._wavelets as wavelets
  4. class TestWavelets:
  5. def test_ricker(self):
  6. w = wavelets._ricker(1.0, 1)
  7. expected = 2 / (np.sqrt(3 * 1.0) * (np.pi ** 0.25))
  8. assert_array_equal(w, expected)
  9. lengths = [5, 11, 15, 51, 101]
  10. for length in lengths:
  11. w = wavelets._ricker(length, 1.0)
  12. assert len(w) == length
  13. max_loc = np.argmax(w)
  14. assert max_loc == (length // 2)
  15. points = 100
  16. w = wavelets._ricker(points, 2.0)
  17. half_vec = np.arange(0, points // 2)
  18. # Wavelet should be symmetric
  19. assert_array_almost_equal(w[half_vec], w[-(half_vec + 1)])
  20. # Check zeros
  21. aas = [5, 10, 15, 20, 30]
  22. points = 99
  23. for a in aas:
  24. w = wavelets._ricker(points, a)
  25. vec = np.arange(0, points) - (points - 1.0) / 2
  26. exp_zero1 = np.argmin(np.abs(vec - a))
  27. exp_zero2 = np.argmin(np.abs(vec + a))
  28. assert_array_almost_equal(w[exp_zero1], 0)
  29. assert_array_almost_equal(w[exp_zero2], 0)
  30. def test_cwt(self):
  31. widths = [1.0]
  32. def delta_wavelet(s, t):
  33. return np.array([1])
  34. len_data = 100
  35. test_data = np.sin(np.pi * np.arange(0, len_data) / 10.0)
  36. # Test delta function input gives same data as output
  37. cwt_dat = wavelets._cwt(test_data, delta_wavelet, widths)
  38. assert cwt_dat.shape == (len(widths), len_data)
  39. assert_array_almost_equal(test_data, cwt_dat.flatten())
  40. # Check proper shape on output
  41. widths = [1, 3, 4, 5, 10]
  42. cwt_dat = wavelets._cwt(test_data, wavelets._ricker, widths)
  43. assert cwt_dat.shape == (len(widths), len_data)
  44. widths = [len_data * 10]
  45. # Note: this wavelet isn't defined quite right, but is fine for this test
  46. def flat_wavelet(l, w):
  47. return np.full(w, 1 / w)
  48. cwt_dat = wavelets._cwt(test_data, flat_wavelet, widths)
  49. assert_array_almost_equal(cwt_dat, np.mean(test_data))