test_hermite.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553
  1. """Tests for hermite module.
  2. """
  3. from functools import reduce
  4. import numpy as np
  5. import numpy.polynomial.hermite as herm
  6. from numpy.polynomial.polynomial import polyval
  7. from numpy.testing import assert_, assert_almost_equal, assert_equal, assert_raises
  8. H0 = np.array([1])
  9. H1 = np.array([0, 2])
  10. H2 = np.array([-2, 0, 4])
  11. H3 = np.array([0, -12, 0, 8])
  12. H4 = np.array([12, 0, -48, 0, 16])
  13. H5 = np.array([0, 120, 0, -160, 0, 32])
  14. H6 = np.array([-120, 0, 720, 0, -480, 0, 64])
  15. H7 = np.array([0, -1680, 0, 3360, 0, -1344, 0, 128])
  16. H8 = np.array([1680, 0, -13440, 0, 13440, 0, -3584, 0, 256])
  17. H9 = np.array([0, 30240, 0, -80640, 0, 48384, 0, -9216, 0, 512])
  18. Hlist = [H0, H1, H2, H3, H4, H5, H6, H7, H8, H9]
  19. def trim(x):
  20. return herm.hermtrim(x, tol=1e-6)
  21. class TestConstants:
  22. def test_hermdomain(self):
  23. assert_equal(herm.hermdomain, [-1, 1])
  24. def test_hermzero(self):
  25. assert_equal(herm.hermzero, [0])
  26. def test_hermone(self):
  27. assert_equal(herm.hermone, [1])
  28. def test_hermx(self):
  29. assert_equal(herm.hermx, [0, .5])
  30. class TestArithmetic:
  31. x = np.linspace(-3, 3, 100)
  32. def test_hermadd(self):
  33. for i in range(5):
  34. for j in range(5):
  35. msg = f"At i={i}, j={j}"
  36. tgt = np.zeros(max(i, j) + 1)
  37. tgt[i] += 1
  38. tgt[j] += 1
  39. res = herm.hermadd([0] * i + [1], [0] * j + [1])
  40. assert_equal(trim(res), trim(tgt), err_msg=msg)
  41. def test_hermsub(self):
  42. for i in range(5):
  43. for j in range(5):
  44. msg = f"At i={i}, j={j}"
  45. tgt = np.zeros(max(i, j) + 1)
  46. tgt[i] += 1
  47. tgt[j] -= 1
  48. res = herm.hermsub([0] * i + [1], [0] * j + [1])
  49. assert_equal(trim(res), trim(tgt), err_msg=msg)
  50. def test_hermmulx(self):
  51. assert_equal(herm.hermmulx([0]), [0])
  52. assert_equal(herm.hermmulx([1]), [0, .5])
  53. for i in range(1, 5):
  54. ser = [0] * i + [1]
  55. tgt = [0] * (i - 1) + [i, 0, .5]
  56. assert_equal(herm.hermmulx(ser), tgt)
  57. def test_hermmul(self):
  58. # check values of result
  59. for i in range(5):
  60. pol1 = [0] * i + [1]
  61. val1 = herm.hermval(self.x, pol1)
  62. for j in range(5):
  63. msg = f"At i={i}, j={j}"
  64. pol2 = [0] * j + [1]
  65. val2 = herm.hermval(self.x, pol2)
  66. pol3 = herm.hermmul(pol1, pol2)
  67. val3 = herm.hermval(self.x, pol3)
  68. assert_(len(pol3) == i + j + 1, msg)
  69. assert_almost_equal(val3, val1 * val2, err_msg=msg)
  70. def test_hermdiv(self):
  71. for i in range(5):
  72. for j in range(5):
  73. msg = f"At i={i}, j={j}"
  74. ci = [0] * i + [1]
  75. cj = [0] * j + [1]
  76. tgt = herm.hermadd(ci, cj)
  77. quo, rem = herm.hermdiv(tgt, ci)
  78. res = herm.hermadd(herm.hermmul(quo, ci), rem)
  79. assert_equal(trim(res), trim(tgt), err_msg=msg)
  80. def test_hermpow(self):
  81. for i in range(5):
  82. for j in range(5):
  83. msg = f"At i={i}, j={j}"
  84. c = np.arange(i + 1)
  85. tgt = reduce(herm.hermmul, [c] * j, np.array([1]))
  86. res = herm.hermpow(c, j)
  87. assert_equal(trim(res), trim(tgt), err_msg=msg)
  88. class TestEvaluation:
  89. # coefficients of 1 + 2*x + 3*x**2
  90. c1d = np.array([2.5, 1., .75])
  91. c2d = np.einsum('i,j->ij', c1d, c1d)
  92. c3d = np.einsum('i,j,k->ijk', c1d, c1d, c1d)
  93. # some random values in [-1, 1)
  94. x = np.random.random((3, 5)) * 2 - 1
  95. y = polyval(x, [1., 2., 3.])
  96. def test_hermval(self):
  97. # check empty input
  98. assert_equal(herm.hermval([], [1]).size, 0)
  99. # check normal input)
  100. x = np.linspace(-1, 1)
  101. y = [polyval(x, c) for c in Hlist]
  102. for i in range(10):
  103. msg = f"At i={i}"
  104. tgt = y[i]
  105. res = herm.hermval(x, [0] * i + [1])
  106. assert_almost_equal(res, tgt, err_msg=msg)
  107. # check that shape is preserved
  108. for i in range(3):
  109. dims = [2] * i
  110. x = np.zeros(dims)
  111. assert_equal(herm.hermval(x, [1]).shape, dims)
  112. assert_equal(herm.hermval(x, [1, 0]).shape, dims)
  113. assert_equal(herm.hermval(x, [1, 0, 0]).shape, dims)
  114. def test_hermval2d(self):
  115. x1, x2, x3 = self.x
  116. y1, y2, y3 = self.y
  117. # test exceptions
  118. assert_raises(ValueError, herm.hermval2d, x1, x2[:2], self.c2d)
  119. # test values
  120. tgt = y1 * y2
  121. res = herm.hermval2d(x1, x2, self.c2d)
  122. assert_almost_equal(res, tgt)
  123. # test shape
  124. z = np.ones((2, 3))
  125. res = herm.hermval2d(z, z, self.c2d)
  126. assert_(res.shape == (2, 3))
  127. def test_hermval3d(self):
  128. x1, x2, x3 = self.x
  129. y1, y2, y3 = self.y
  130. # test exceptions
  131. assert_raises(ValueError, herm.hermval3d, x1, x2, x3[:2], self.c3d)
  132. # test values
  133. tgt = y1 * y2 * y3
  134. res = herm.hermval3d(x1, x2, x3, self.c3d)
  135. assert_almost_equal(res, tgt)
  136. # test shape
  137. z = np.ones((2, 3))
  138. res = herm.hermval3d(z, z, z, self.c3d)
  139. assert_(res.shape == (2, 3))
  140. def test_hermgrid2d(self):
  141. x1, x2, x3 = self.x
  142. y1, y2, y3 = self.y
  143. # test values
  144. tgt = np.einsum('i,j->ij', y1, y2)
  145. res = herm.hermgrid2d(x1, x2, self.c2d)
  146. assert_almost_equal(res, tgt)
  147. # test shape
  148. z = np.ones((2, 3))
  149. res = herm.hermgrid2d(z, z, self.c2d)
  150. assert_(res.shape == (2, 3) * 2)
  151. def test_hermgrid3d(self):
  152. x1, x2, x3 = self.x
  153. y1, y2, y3 = self.y
  154. # test values
  155. tgt = np.einsum('i,j,k->ijk', y1, y2, y3)
  156. res = herm.hermgrid3d(x1, x2, x3, self.c3d)
  157. assert_almost_equal(res, tgt)
  158. # test shape
  159. z = np.ones((2, 3))
  160. res = herm.hermgrid3d(z, z, z, self.c3d)
  161. assert_(res.shape == (2, 3) * 3)
  162. class TestIntegral:
  163. def test_hermint(self):
  164. # check exceptions
  165. assert_raises(TypeError, herm.hermint, [0], .5)
  166. assert_raises(ValueError, herm.hermint, [0], -1)
  167. assert_raises(ValueError, herm.hermint, [0], 1, [0, 0])
  168. assert_raises(ValueError, herm.hermint, [0], lbnd=[0])
  169. assert_raises(ValueError, herm.hermint, [0], scl=[0])
  170. assert_raises(TypeError, herm.hermint, [0], axis=.5)
  171. # test integration of zero polynomial
  172. for i in range(2, 5):
  173. k = [0] * (i - 2) + [1]
  174. res = herm.hermint([0], m=i, k=k)
  175. assert_almost_equal(res, [0, .5])
  176. # check single integration with integration constant
  177. for i in range(5):
  178. scl = i + 1
  179. pol = [0] * i + [1]
  180. tgt = [i] + [0] * i + [1 / scl]
  181. hermpol = herm.poly2herm(pol)
  182. hermint = herm.hermint(hermpol, m=1, k=[i])
  183. res = herm.herm2poly(hermint)
  184. assert_almost_equal(trim(res), trim(tgt))
  185. # check single integration with integration constant and lbnd
  186. for i in range(5):
  187. scl = i + 1
  188. pol = [0] * i + [1]
  189. hermpol = herm.poly2herm(pol)
  190. hermint = herm.hermint(hermpol, m=1, k=[i], lbnd=-1)
  191. assert_almost_equal(herm.hermval(-1, hermint), i)
  192. # check single integration with integration constant and scaling
  193. for i in range(5):
  194. scl = i + 1
  195. pol = [0] * i + [1]
  196. tgt = [i] + [0] * i + [2 / scl]
  197. hermpol = herm.poly2herm(pol)
  198. hermint = herm.hermint(hermpol, m=1, k=[i], scl=2)
  199. res = herm.herm2poly(hermint)
  200. assert_almost_equal(trim(res), trim(tgt))
  201. # check multiple integrations with default k
  202. for i in range(5):
  203. for j in range(2, 5):
  204. pol = [0] * i + [1]
  205. tgt = pol[:]
  206. for k in range(j):
  207. tgt = herm.hermint(tgt, m=1)
  208. res = herm.hermint(pol, m=j)
  209. assert_almost_equal(trim(res), trim(tgt))
  210. # check multiple integrations with defined k
  211. for i in range(5):
  212. for j in range(2, 5):
  213. pol = [0] * i + [1]
  214. tgt = pol[:]
  215. for k in range(j):
  216. tgt = herm.hermint(tgt, m=1, k=[k])
  217. res = herm.hermint(pol, m=j, k=list(range(j)))
  218. assert_almost_equal(trim(res), trim(tgt))
  219. # check multiple integrations with lbnd
  220. for i in range(5):
  221. for j in range(2, 5):
  222. pol = [0] * i + [1]
  223. tgt = pol[:]
  224. for k in range(j):
  225. tgt = herm.hermint(tgt, m=1, k=[k], lbnd=-1)
  226. res = herm.hermint(pol, m=j, k=list(range(j)), lbnd=-1)
  227. assert_almost_equal(trim(res), trim(tgt))
  228. # check multiple integrations with scaling
  229. for i in range(5):
  230. for j in range(2, 5):
  231. pol = [0] * i + [1]
  232. tgt = pol[:]
  233. for k in range(j):
  234. tgt = herm.hermint(tgt, m=1, k=[k], scl=2)
  235. res = herm.hermint(pol, m=j, k=list(range(j)), scl=2)
  236. assert_almost_equal(trim(res), trim(tgt))
  237. def test_hermint_axis(self):
  238. # check that axis keyword works
  239. c2d = np.random.random((3, 4))
  240. tgt = np.vstack([herm.hermint(c) for c in c2d.T]).T
  241. res = herm.hermint(c2d, axis=0)
  242. assert_almost_equal(res, tgt)
  243. tgt = np.vstack([herm.hermint(c) for c in c2d])
  244. res = herm.hermint(c2d, axis=1)
  245. assert_almost_equal(res, tgt)
  246. tgt = np.vstack([herm.hermint(c, k=3) for c in c2d])
  247. res = herm.hermint(c2d, k=3, axis=1)
  248. assert_almost_equal(res, tgt)
  249. class TestDerivative:
  250. def test_hermder(self):
  251. # check exceptions
  252. assert_raises(TypeError, herm.hermder, [0], .5)
  253. assert_raises(ValueError, herm.hermder, [0], -1)
  254. # check that zeroth derivative does nothing
  255. for i in range(5):
  256. tgt = [0] * i + [1]
  257. res = herm.hermder(tgt, m=0)
  258. assert_equal(trim(res), trim(tgt))
  259. # check that derivation is the inverse of integration
  260. for i in range(5):
  261. for j in range(2, 5):
  262. tgt = [0] * i + [1]
  263. res = herm.hermder(herm.hermint(tgt, m=j), m=j)
  264. assert_almost_equal(trim(res), trim(tgt))
  265. # check derivation with scaling
  266. for i in range(5):
  267. for j in range(2, 5):
  268. tgt = [0] * i + [1]
  269. res = herm.hermder(herm.hermint(tgt, m=j, scl=2), m=j, scl=.5)
  270. assert_almost_equal(trim(res), trim(tgt))
  271. def test_hermder_axis(self):
  272. # check that axis keyword works
  273. c2d = np.random.random((3, 4))
  274. tgt = np.vstack([herm.hermder(c) for c in c2d.T]).T
  275. res = herm.hermder(c2d, axis=0)
  276. assert_almost_equal(res, tgt)
  277. tgt = np.vstack([herm.hermder(c) for c in c2d])
  278. res = herm.hermder(c2d, axis=1)
  279. assert_almost_equal(res, tgt)
  280. class TestVander:
  281. # some random values in [-1, 1)
  282. x = np.random.random((3, 5)) * 2 - 1
  283. def test_hermvander(self):
  284. # check for 1d x
  285. x = np.arange(3)
  286. v = herm.hermvander(x, 3)
  287. assert_(v.shape == (3, 4))
  288. for i in range(4):
  289. coef = [0] * i + [1]
  290. assert_almost_equal(v[..., i], herm.hermval(x, coef))
  291. # check for 2d x
  292. x = np.array([[1, 2], [3, 4], [5, 6]])
  293. v = herm.hermvander(x, 3)
  294. assert_(v.shape == (3, 2, 4))
  295. for i in range(4):
  296. coef = [0] * i + [1]
  297. assert_almost_equal(v[..., i], herm.hermval(x, coef))
  298. def test_hermvander2d(self):
  299. # also tests hermval2d for non-square coefficient array
  300. x1, x2, x3 = self.x
  301. c = np.random.random((2, 3))
  302. van = herm.hermvander2d(x1, x2, [1, 2])
  303. tgt = herm.hermval2d(x1, x2, c)
  304. res = np.dot(van, c.flat)
  305. assert_almost_equal(res, tgt)
  306. # check shape
  307. van = herm.hermvander2d([x1], [x2], [1, 2])
  308. assert_(van.shape == (1, 5, 6))
  309. def test_hermvander3d(self):
  310. # also tests hermval3d for non-square coefficient array
  311. x1, x2, x3 = self.x
  312. c = np.random.random((2, 3, 4))
  313. van = herm.hermvander3d(x1, x2, x3, [1, 2, 3])
  314. tgt = herm.hermval3d(x1, x2, x3, c)
  315. res = np.dot(van, c.flat)
  316. assert_almost_equal(res, tgt)
  317. # check shape
  318. van = herm.hermvander3d([x1], [x2], [x3], [1, 2, 3])
  319. assert_(van.shape == (1, 5, 24))
  320. class TestFitting:
  321. def test_hermfit(self):
  322. def f(x):
  323. return x * (x - 1) * (x - 2)
  324. def f2(x):
  325. return x**4 + x**2 + 1
  326. # Test exceptions
  327. assert_raises(ValueError, herm.hermfit, [1], [1], -1)
  328. assert_raises(TypeError, herm.hermfit, [[1]], [1], 0)
  329. assert_raises(TypeError, herm.hermfit, [], [1], 0)
  330. assert_raises(TypeError, herm.hermfit, [1], [[[1]]], 0)
  331. assert_raises(TypeError, herm.hermfit, [1, 2], [1], 0)
  332. assert_raises(TypeError, herm.hermfit, [1], [1, 2], 0)
  333. assert_raises(TypeError, herm.hermfit, [1], [1], 0, w=[[1]])
  334. assert_raises(TypeError, herm.hermfit, [1], [1], 0, w=[1, 1])
  335. assert_raises(ValueError, herm.hermfit, [1], [1], [-1,])
  336. assert_raises(ValueError, herm.hermfit, [1], [1], [2, -1, 6])
  337. assert_raises(TypeError, herm.hermfit, [1], [1], [])
  338. # Test fit
  339. x = np.linspace(0, 2)
  340. y = f(x)
  341. #
  342. coef3 = herm.hermfit(x, y, 3)
  343. assert_equal(len(coef3), 4)
  344. assert_almost_equal(herm.hermval(x, coef3), y)
  345. coef3 = herm.hermfit(x, y, [0, 1, 2, 3])
  346. assert_equal(len(coef3), 4)
  347. assert_almost_equal(herm.hermval(x, coef3), y)
  348. #
  349. coef4 = herm.hermfit(x, y, 4)
  350. assert_equal(len(coef4), 5)
  351. assert_almost_equal(herm.hermval(x, coef4), y)
  352. coef4 = herm.hermfit(x, y, [0, 1, 2, 3, 4])
  353. assert_equal(len(coef4), 5)
  354. assert_almost_equal(herm.hermval(x, coef4), y)
  355. # check things still work if deg is not in strict increasing
  356. coef4 = herm.hermfit(x, y, [2, 3, 4, 1, 0])
  357. assert_equal(len(coef4), 5)
  358. assert_almost_equal(herm.hermval(x, coef4), y)
  359. #
  360. coef2d = herm.hermfit(x, np.array([y, y]).T, 3)
  361. assert_almost_equal(coef2d, np.array([coef3, coef3]).T)
  362. coef2d = herm.hermfit(x, np.array([y, y]).T, [0, 1, 2, 3])
  363. assert_almost_equal(coef2d, np.array([coef3, coef3]).T)
  364. # test weighting
  365. w = np.zeros_like(x)
  366. yw = y.copy()
  367. w[1::2] = 1
  368. y[0::2] = 0
  369. wcoef3 = herm.hermfit(x, yw, 3, w=w)
  370. assert_almost_equal(wcoef3, coef3)
  371. wcoef3 = herm.hermfit(x, yw, [0, 1, 2, 3], w=w)
  372. assert_almost_equal(wcoef3, coef3)
  373. #
  374. wcoef2d = herm.hermfit(x, np.array([yw, yw]).T, 3, w=w)
  375. assert_almost_equal(wcoef2d, np.array([coef3, coef3]).T)
  376. wcoef2d = herm.hermfit(x, np.array([yw, yw]).T, [0, 1, 2, 3], w=w)
  377. assert_almost_equal(wcoef2d, np.array([coef3, coef3]).T)
  378. # test scaling with complex values x points whose square
  379. # is zero when summed.
  380. x = [1, 1j, -1, -1j]
  381. assert_almost_equal(herm.hermfit(x, x, 1), [0, .5])
  382. assert_almost_equal(herm.hermfit(x, x, [0, 1]), [0, .5])
  383. # test fitting only even Legendre polynomials
  384. x = np.linspace(-1, 1)
  385. y = f2(x)
  386. coef1 = herm.hermfit(x, y, 4)
  387. assert_almost_equal(herm.hermval(x, coef1), y)
  388. coef2 = herm.hermfit(x, y, [0, 2, 4])
  389. assert_almost_equal(herm.hermval(x, coef2), y)
  390. assert_almost_equal(coef1, coef2)
  391. class TestCompanion:
  392. def test_raises(self):
  393. assert_raises(ValueError, herm.hermcompanion, [])
  394. assert_raises(ValueError, herm.hermcompanion, [1])
  395. def test_dimensions(self):
  396. for i in range(1, 5):
  397. coef = [0] * i + [1]
  398. assert_(herm.hermcompanion(coef).shape == (i, i))
  399. def test_linear_root(self):
  400. assert_(herm.hermcompanion([1, 2])[0, 0] == -.25)
  401. class TestGauss:
  402. def test_100(self):
  403. x, w = herm.hermgauss(100)
  404. # test orthogonality. Note that the results need to be normalized,
  405. # otherwise the huge values that can arise from fast growing
  406. # functions like Laguerre can be very confusing.
  407. v = herm.hermvander(x, 99)
  408. vv = np.dot(v.T * w, v)
  409. vd = 1 / np.sqrt(vv.diagonal())
  410. vv = vd[:, None] * vv * vd
  411. assert_almost_equal(vv, np.eye(100))
  412. # check that the integral of 1 is correct
  413. tgt = np.sqrt(np.pi)
  414. assert_almost_equal(w.sum(), tgt)
  415. class TestMisc:
  416. def test_hermfromroots(self):
  417. res = herm.hermfromroots([])
  418. assert_almost_equal(trim(res), [1])
  419. for i in range(1, 5):
  420. roots = np.cos(np.linspace(-np.pi, 0, 2 * i + 1)[1::2])
  421. pol = herm.hermfromroots(roots)
  422. res = herm.hermval(roots, pol)
  423. tgt = 0
  424. assert_(len(pol) == i + 1)
  425. assert_almost_equal(herm.herm2poly(pol)[-1], 1)
  426. assert_almost_equal(res, tgt)
  427. def test_hermroots(self):
  428. assert_almost_equal(herm.hermroots([1]), [])
  429. assert_almost_equal(herm.hermroots([1, 1]), [-.5])
  430. for i in range(2, 5):
  431. tgt = np.linspace(-1, 1, i)
  432. res = herm.hermroots(herm.hermfromroots(tgt))
  433. assert_almost_equal(trim(res), trim(tgt))
  434. def test_hermtrim(self):
  435. coef = [2, -1, 1, 0]
  436. # Test exceptions
  437. assert_raises(ValueError, herm.hermtrim, coef, -1)
  438. # Test results
  439. assert_equal(herm.hermtrim(coef), coef[:-1])
  440. assert_equal(herm.hermtrim(coef, 1), coef[:-3])
  441. assert_equal(herm.hermtrim(coef, 2), [0])
  442. def test_hermline(self):
  443. assert_equal(herm.hermline(3, 4), [3, 2])
  444. def test_herm2poly(self):
  445. for i in range(10):
  446. assert_almost_equal(herm.herm2poly([0] * i + [1]), Hlist[i])
  447. def test_poly2herm(self):
  448. for i in range(10):
  449. assert_almost_equal(herm.poly2herm(Hlist[i]), [0] * i + [1])
  450. def test_weight(self):
  451. x = np.linspace(-5, 5, 11)
  452. tgt = np.exp(-x**2)
  453. res = herm.hermweight(x)
  454. assert_almost_equal(res, tgt)