test_waveforms.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. import numpy as np
  2. from pytest import raises as assert_raises
  3. from scipy._lib._array_api import (
  4. assert_almost_equal, xp_assert_equal, xp_assert_close
  5. )
  6. import scipy.signal._waveforms as waveforms
  # The chirp_* helpers below return the instantaneous frequency that the
  # signal produced by chirp() should have for the corresponding `method`.
  def chirp_linear(t, f0, f1, t1):
      """Expected frequency of a 'linear' chirp.

      The frequency moves in a straight line from ``f0`` at ``t = 0`` to
      ``f1`` at ``t = t1``.
      """
      return f0 + (f1 - f0) * t / t1
  def chirp_quadratic(t, f0, f1, t1, vertex_zero=True):
      """Expected frequency of a 'quadratic' chirp.

      With ``vertex_zero`` the parabola's vertex sits at ``t = 0`` (value
      ``f0``); otherwise it sits at ``t = t1`` (value ``f1``). In both cases
      the curve passes through ``(0, f0)`` and ``(t1, f1)``.
      """
      if vertex_zero:
          return f0 + (f1 - f0) * t**2 / t1**2
      # Vertex at t1: measure the distance back from the end of the sweep.
      return f1 - (f1 - f0) * (t1 - t)**2 / t1**2
  def chirp_geometric(t, f0, f1, t1):
      """Expected frequency of a 'logarithmic' (geometric) chirp.

      The frequency is multiplied by the constant ratio ``f1/f0`` over every
      interval of length ``t1``.
      """
      ratio = f1/f0
      return f0 * ratio**(t/t1)
  def chirp_hyperbolic(t, f0, f1, t1):
      """Expected frequency of a 'hyperbolic' chirp.

      Equals ``f0`` at ``t = 0`` and ``f1`` at ``t = t1``, with ``1/f``
      varying linearly in ``t``.
      """
      numerator = f0*f1*t1
      denominator = (f0 - f1)*t + f1*t1
      return numerator / denominator
  def compute_frequency(t, theta):
      """
      Numerically estimate theta'(t)/(2*pi), the instantaneous frequency.

      ``t`` and ``theta`` are assumed to be 1-D NumPy arrays sampled at the
      same points, with ``t`` uniformly spaced. Each forward difference of
      ``theta`` is attributed to the midpoint of its sampling interval.

      Returns
      -------
      tf : ndarray
          Midpoints of consecutive samples of ``t`` (one element shorter).
      f : ndarray
          Estimated frequency at each point of ``tf``.
      """
      # Uniform spacing is assumed, so the first gap stands for all of them.
      step = t[1] - t[0]
      cycles = np.diff(theta) / (2*np.pi)
      freq = cycles / step
      midpoints = 0.5*(t[:-1] + t[1:])
      return midpoints, freq
  34. class TestChirp:
  35. def test_linear_at_zero(self):
  36. w = waveforms.chirp(t=0, f0=1.0, f1=2.0, t1=1.0, method='linear')
  37. assert_almost_equal(w, 1.0)
  38. def test_linear_freq_01(self):
  39. method = 'linear'
  40. f0 = 1.0
  41. f1 = 2.0
  42. t1 = 1.0
  43. t = np.linspace(0, t1, 100)
  44. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  45. tf, f = compute_frequency(t, phase)
  46. abserr = np.max(np.abs(f - chirp_linear(tf, f0, f1, t1)))
  47. assert abserr < 1e-6
  48. def test_linear_freq_02(self):
  49. method = 'linear'
  50. f0 = 200.0
  51. f1 = 100.0
  52. t1 = 10.0
  53. t = np.linspace(0, t1, 100)
  54. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  55. tf, f = compute_frequency(t, phase)
  56. abserr = np.max(np.abs(f - chirp_linear(tf, f0, f1, t1)))
  57. assert abserr < 1e-6
  58. def test_linear_complex_power(self):
  59. method = 'linear'
  60. f0 = 1.0
  61. f1 = 2.0
  62. t1 = 1.0
  63. t = np.linspace(0, t1, 100)
  64. w_real = waveforms.chirp(t, f0, t1, f1, method, complex=False)
  65. w_complex = waveforms.chirp(t, f0, t1, f1, method, complex=True)
  66. w_pwr_r = np.var(w_real)
  67. w_pwr_c = np.var(w_complex)
  68. # Making sure that power of the real part is not affected with
  69. # complex conversion operation
  70. err = w_pwr_r - np.real(w_pwr_c)
  71. assert(err < 1e-6)
  72. def test_linear_complex_at_zero(self):
  73. w = waveforms.chirp(t=0, f0=-10.0, f1=1.0, t1=1.0, method='linear',
  74. complex=True)
  75. xp_assert_close(w, 1.0+0.0j) # dtype must match
  76. def test_quadratic_at_zero(self):
  77. w = waveforms.chirp(t=0, f0=1.0, f1=2.0, t1=1.0, method='quadratic')
  78. assert_almost_equal(w, 1.0)
  79. def test_quadratic_at_zero2(self):
  80. w = waveforms.chirp(t=0, f0=1.0, f1=2.0, t1=1.0, method='quadratic',
  81. vertex_zero=False)
  82. assert_almost_equal(w, 1.0)
  83. def test_quadratic_complex_at_zero(self):
  84. w = waveforms.chirp(t=0, f0=-1.0, f1=2.0, t1=1.0, method='quadratic',
  85. complex=True)
  86. xp_assert_close(w, 1.0+0j)
  87. def test_quadratic_freq_01(self):
  88. method = 'quadratic'
  89. f0 = 1.0
  90. f1 = 2.0
  91. t1 = 1.0
  92. t = np.linspace(0, t1, 2000)
  93. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  94. tf, f = compute_frequency(t, phase)
  95. abserr = np.max(np.abs(f - chirp_quadratic(tf, f0, f1, t1)))
  96. assert abserr < 1e-6
  97. def test_quadratic_freq_02(self):
  98. method = 'quadratic'
  99. f0 = 20.0
  100. f1 = 10.0
  101. t1 = 10.0
  102. t = np.linspace(0, t1, 2000)
  103. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  104. tf, f = compute_frequency(t, phase)
  105. abserr = np.max(np.abs(f - chirp_quadratic(tf, f0, f1, t1)))
  106. assert abserr < 1e-6
  107. def test_logarithmic_at_zero(self):
  108. w = waveforms.chirp(t=0, f0=1.0, f1=2.0, t1=1.0, method='logarithmic')
  109. assert_almost_equal(w, 1.0)
  110. def test_logarithmic_freq_01(self):
  111. method = 'logarithmic'
  112. f0 = 1.0
  113. f1 = 2.0
  114. t1 = 1.0
  115. t = np.linspace(0, t1, 10000)
  116. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  117. tf, f = compute_frequency(t, phase)
  118. abserr = np.max(np.abs(f - chirp_geometric(tf, f0, f1, t1)))
  119. assert abserr < 1e-6
  120. def test_logarithmic_freq_02(self):
  121. method = 'logarithmic'
  122. f0 = 200.0
  123. f1 = 100.0
  124. t1 = 10.0
  125. t = np.linspace(0, t1, 10000)
  126. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  127. tf, f = compute_frequency(t, phase)
  128. abserr = np.max(np.abs(f - chirp_geometric(tf, f0, f1, t1)))
  129. assert abserr < 1e-6
  130. def test_logarithmic_freq_03(self):
  131. method = 'logarithmic'
  132. f0 = 100.0
  133. f1 = 100.0
  134. t1 = 10.0
  135. t = np.linspace(0, t1, 10000)
  136. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  137. tf, f = compute_frequency(t, phase)
  138. abserr = np.max(np.abs(f - chirp_geometric(tf, f0, f1, t1)))
  139. assert abserr < 1e-6
  140. def test_hyperbolic_at_zero(self):
  141. w = waveforms.chirp(t=0, f0=10.0, f1=1.0, t1=1.0, method='hyperbolic')
  142. assert_almost_equal(w, 1.0)
  143. def test_hyperbolic_freq_01(self):
  144. method = 'hyperbolic'
  145. t1 = 1.0
  146. t = np.linspace(0, t1, 10000)
  147. # f0 f1
  148. cases = [[10.0, 1.0],
  149. [1.0, 10.0],
  150. [-10.0, -1.0],
  151. [-1.0, -10.0]]
  152. for f0, f1 in cases:
  153. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  154. tf, f = compute_frequency(t, phase)
  155. expected = chirp_hyperbolic(tf, f0, f1, t1)
  156. xp_assert_close(f, expected, atol=1e-7)
  157. def test_hyperbolic_zero_freq(self):
  158. # f0=0 or f1=0 must raise a ValueError.
  159. method = 'hyperbolic'
  160. t1 = 1.0
  161. t = np.linspace(0, t1, 5)
  162. assert_raises(ValueError, waveforms.chirp, t, 0, t1, 1, method)
  163. assert_raises(ValueError, waveforms.chirp, t, 1, t1, 0, method)
  164. def test_unknown_method(self):
  165. method = "foo"
  166. f0 = 10.0
  167. f1 = 20.0
  168. t1 = 1.0
  169. t = np.linspace(0, t1, 10)
  170. assert_raises(ValueError, waveforms.chirp, t, f0, t1, f1, method)
  171. def test_integer_t1(self):
  172. f0 = 10.0
  173. f1 = 20.0
  174. t = np.linspace(-1, 1, 11)
  175. t1 = 3.0
  176. float_result = waveforms.chirp(t, f0, t1, f1)
  177. t1 = 3
  178. int_result = waveforms.chirp(t, f0, t1, f1)
  179. err_msg = "Integer input 't1=3' gives wrong result"
  180. xp_assert_equal(int_result, float_result, err_msg=err_msg)
  181. def test_integer_f0(self):
  182. f1 = 20.0
  183. t1 = 3.0
  184. t = np.linspace(-1, 1, 11)
  185. f0 = 10.0
  186. float_result = waveforms.chirp(t, f0, t1, f1)
  187. f0 = 10
  188. int_result = waveforms.chirp(t, f0, t1, f1)
  189. err_msg = "Integer input 'f0=10' gives wrong result"
  190. xp_assert_equal(int_result, float_result, err_msg=err_msg)
  191. def test_integer_f1(self):
  192. f0 = 10.0
  193. t1 = 3.0
  194. t = np.linspace(-1, 1, 11)
  195. f1 = 20.0
  196. float_result = waveforms.chirp(t, f0, t1, f1)
  197. f1 = 20
  198. int_result = waveforms.chirp(t, f0, t1, f1)
  199. err_msg = "Integer input 'f1=20' gives wrong result"
  200. xp_assert_equal(int_result, float_result, err_msg=err_msg)
  201. def test_integer_all(self):
  202. f0 = 10
  203. t1 = 3
  204. f1 = 20
  205. t = np.linspace(-1, 1, 11)
  206. float_result = waveforms.chirp(t, float(f0), float(t1), float(f1))
  207. int_result = waveforms.chirp(t, f0, t1, f1)
  208. err_msg = "Integer input 'f0=10, t1=3, f1=20' gives wrong result"
  209. xp_assert_equal(int_result, float_result, err_msg=err_msg)
  class TestSweepPoly:
      """The phase from ``_sweep_poly_phase`` must differentiate back to the
      sweep polynomial."""

      @staticmethod
      def _max_freq_error(poly, t_end):
          """Largest gap between the numerical and the exact frequency.

          ``poly`` may be a ``np.poly1d``, an array or a list of
          coefficients. It is sampled on 10000 points of ``[0, t_end]``.
          """
          samples = np.linspace(0, t_end, 10000)
          midpoints, freq = compute_frequency(
              samples, waveforms._sweep_poly_phase(samples, poly))
          exact = np.poly1d(poly)(midpoints)
          return np.max(np.abs(freq - exact))

      def test_sweep_poly_quad1(self):
          assert self._max_freq_error(np.poly1d([1.0, 0.0, 1.0]), 3.0) < 1e-6

      def test_sweep_poly_const(self):
          assert self._max_freq_error(np.poly1d(2.0), 3.0) < 1e-6

      def test_sweep_poly_linear(self):
          assert self._max_freq_error(np.poly1d([-1.0, 10.0]), 3.0) < 1e-6

      def test_sweep_poly_quad2(self):
          assert self._max_freq_error(np.poly1d([1.0, 0.0, -2.0]), 3.0) < 1e-6

      def test_sweep_poly_cubic(self):
          cubic = np.poly1d([2.0, 1.0, 0.0, -2.0])
          assert self._max_freq_error(cubic, 2.0) < 1e-6

      def test_sweep_poly_cubic2(self):
          """Use an array of coefficients instead of a poly1d."""
          coeffs = np.array([2.0, 1.0, 0.0, -2.0])
          assert self._max_freq_error(coeffs, 2.0) < 1e-6

      def test_sweep_poly_cubic3(self):
          """Use a list of coefficients instead of a poly1d."""
          coeffs = [2.0, 1.0, 0.0, -2.0]
          assert self._max_freq_error(coeffs, 2.0) < 1e-6
  269. class TestGaussPulse:
  270. def test_integer_fc(self):
  271. float_result = waveforms.gausspulse('cutoff', fc=1000.0)
  272. int_result = waveforms.gausspulse('cutoff', fc=1000)
  273. err_msg = "Integer input 'fc=1000' gives wrong result"
  274. xp_assert_equal(int_result, float_result, err_msg=err_msg)
  275. def test_integer_bw(self):
  276. float_result = waveforms.gausspulse('cutoff', bw=1.0)
  277. int_result = waveforms.gausspulse('cutoff', bw=1)
  278. err_msg = "Integer input 'bw=1' gives wrong result"
  279. xp_assert_equal(int_result, float_result, err_msg=err_msg)
  280. def test_integer_bwr(self):
  281. float_result = waveforms.gausspulse('cutoff', bwr=-6.0)
  282. int_result = waveforms.gausspulse('cutoff', bwr=-6)
  283. err_msg = "Integer input 'bwr=-6' gives wrong result"
  284. xp_assert_equal(int_result, float_result, err_msg=err_msg)
  285. def test_integer_tpr(self):
  286. float_result = waveforms.gausspulse('cutoff', tpr=-60.0)
  287. int_result = waveforms.gausspulse('cutoff', tpr=-60)
  288. err_msg = "Integer input 'tpr=-60' gives wrong result"
  289. xp_assert_equal(int_result, float_result, err_msg=err_msg)
  290. class TestUnitImpulse:
  291. def test_no_index(self):
  292. xp_assert_equal(waveforms.unit_impulse(7),
  293. np.asarray([1.0, 0, 0, 0, 0, 0, 0]))
  294. xp_assert_equal(waveforms.unit_impulse((3, 3)),
  295. np.asarray([[1.0, 0, 0], [0, 0, 0], [0, 0, 0]]))
  296. def test_index(self):
  297. xp_assert_equal(waveforms.unit_impulse(10, 3),
  298. np.asarray([0.0, 0, 0, 1, 0, 0, 0, 0, 0, 0]))
  299. xp_assert_equal(waveforms.unit_impulse((3, 3), (1, 1)),
  300. np.asarray([[0.0, 0, 0], [0, 1, 0], [0, 0, 0]]))
  301. # Broadcasting
  302. imp = waveforms.unit_impulse((4, 4), 2)
  303. xp_assert_equal(imp, np.asarray([[0.0, 0, 0, 0],
  304. [0.0, 0, 0, 0],
  305. [0.0, 0, 1, 0],
  306. [0.0, 0, 0, 0]]))
  307. def test_mid(self):
  308. xp_assert_equal(waveforms.unit_impulse((3, 3), 'mid'),
  309. np.asarray([[0.0, 0, 0], [0, 1, 0], [0, 0, 0]]))
  310. xp_assert_equal(waveforms.unit_impulse(9, 'mid'),
  311. np.asarray([0.0, 0, 0, 0, 1, 0, 0, 0, 0]))
  312. def test_dtype(self):
  313. imp = waveforms.unit_impulse(7)
  314. assert np.issubdtype(imp.dtype, np.floating)
  315. imp = waveforms.unit_impulse(5, 3, dtype=int)
  316. assert np.issubdtype(imp.dtype, np.integer)
  317. imp = waveforms.unit_impulse((5, 2), (3, 1), dtype=complex)
  318. assert np.issubdtype(imp.dtype, np.complexfloating)
  class TestSawtoothWaveform:
      """``sawtooth`` returns float64 for both float32-array and int input."""

      def test_dtype(self):
          from_float32 = waveforms.sawtooth(np.array(1, dtype=np.float32),
                                            width=np.float32(1))
          assert from_float32.dtype == np.float64
          from_int = waveforms.sawtooth(1)
          assert from_int.dtype == np.float64
  class TestSquareWaveform:
      """``square`` returns float64 for both float32-array and int input."""

      def test_dtype(self):
          from_float32 = waveforms.square(np.array(1, dtype=np.float32),
                                          duty=np.float32(0.5))
          assert from_float32.dtype == np.float64
          from_int = waveforms.square(1)
          assert from_int.dtype == np.float64