test_polynomial.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647
  1. """Tests for polynomial module.
  2. """
  3. from functools import reduce
  4. from fractions import Fraction
  5. import numpy as np
  6. import numpy.polynomial.polynomial as poly
  7. import numpy.polynomial.polyutils as pu
  8. import pickle
  9. from copy import deepcopy
  10. from numpy.testing import (
  11. assert_almost_equal, assert_raises, assert_equal, assert_,
  12. assert_array_equal, assert_raises_regex, assert_warns)
  13. def trim(x):
  14. return poly.polytrim(x, tol=1e-6)
  15. T0 = [1]
  16. T1 = [0, 1]
  17. T2 = [-1, 0, 2]
  18. T3 = [0, -3, 0, 4]
  19. T4 = [1, 0, -8, 0, 8]
  20. T5 = [0, 5, 0, -20, 0, 16]
  21. T6 = [-1, 0, 18, 0, -48, 0, 32]
  22. T7 = [0, -7, 0, 56, 0, -112, 0, 64]
  23. T8 = [1, 0, -32, 0, 160, 0, -256, 0, 128]
  24. T9 = [0, 9, 0, -120, 0, 432, 0, -576, 0, 256]
  25. Tlist = [T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]
  26. class TestConstants:
  27. def test_polydomain(self):
  28. assert_equal(poly.polydomain, [-1, 1])
  29. def test_polyzero(self):
  30. assert_equal(poly.polyzero, [0])
  31. def test_polyone(self):
  32. assert_equal(poly.polyone, [1])
  33. def test_polyx(self):
  34. assert_equal(poly.polyx, [0, 1])
  35. def test_copy(self):
  36. x = poly.Polynomial([1, 2, 3])
  37. y = deepcopy(x)
  38. assert_equal(x, y)
  39. def test_pickle(self):
  40. x = poly.Polynomial([1, 2, 3])
  41. y = pickle.loads(pickle.dumps(x))
  42. assert_equal(x, y)
  43. class TestArithmetic:
  44. def test_polyadd(self):
  45. for i in range(5):
  46. for j in range(5):
  47. msg = f"At i={i}, j={j}"
  48. tgt = np.zeros(max(i, j) + 1)
  49. tgt[i] += 1
  50. tgt[j] += 1
  51. res = poly.polyadd([0]*i + [1], [0]*j + [1])
  52. assert_equal(trim(res), trim(tgt), err_msg=msg)
  53. def test_polysub(self):
  54. for i in range(5):
  55. for j in range(5):
  56. msg = f"At i={i}, j={j}"
  57. tgt = np.zeros(max(i, j) + 1)
  58. tgt[i] += 1
  59. tgt[j] -= 1
  60. res = poly.polysub([0]*i + [1], [0]*j + [1])
  61. assert_equal(trim(res), trim(tgt), err_msg=msg)
  62. def test_polymulx(self):
  63. assert_equal(poly.polymulx([0]), [0])
  64. assert_equal(poly.polymulx([1]), [0, 1])
  65. for i in range(1, 5):
  66. ser = [0]*i + [1]
  67. tgt = [0]*(i + 1) + [1]
  68. assert_equal(poly.polymulx(ser), tgt)
  69. def test_polymul(self):
  70. for i in range(5):
  71. for j in range(5):
  72. msg = f"At i={i}, j={j}"
  73. tgt = np.zeros(i + j + 1)
  74. tgt[i + j] += 1
  75. res = poly.polymul([0]*i + [1], [0]*j + [1])
  76. assert_equal(trim(res), trim(tgt), err_msg=msg)
  77. def test_polydiv(self):
  78. # check zero division
  79. assert_raises(ZeroDivisionError, poly.polydiv, [1], [0])
  80. # check scalar division
  81. quo, rem = poly.polydiv([2], [2])
  82. assert_equal((quo, rem), (1, 0))
  83. quo, rem = poly.polydiv([2, 2], [2])
  84. assert_equal((quo, rem), ((1, 1), 0))
  85. # check rest.
  86. for i in range(5):
  87. for j in range(5):
  88. msg = f"At i={i}, j={j}"
  89. ci = [0]*i + [1, 2]
  90. cj = [0]*j + [1, 2]
  91. tgt = poly.polyadd(ci, cj)
  92. quo, rem = poly.polydiv(tgt, ci)
  93. res = poly.polyadd(poly.polymul(quo, ci), rem)
  94. assert_equal(res, tgt, err_msg=msg)
  95. def test_polypow(self):
  96. for i in range(5):
  97. for j in range(5):
  98. msg = f"At i={i}, j={j}"
  99. c = np.arange(i + 1)
  100. tgt = reduce(poly.polymul, [c]*j, np.array([1]))
  101. res = poly.polypow(c, j)
  102. assert_equal(trim(res), trim(tgt), err_msg=msg)
  103. class TestFraction:
  104. def test_Fraction(self):
  105. # assert we can use Polynomials with coefficients of object dtype
  106. f = Fraction(2, 3)
  107. one = Fraction(1, 1)
  108. zero = Fraction(0, 1)
  109. p = poly.Polynomial([f, f], domain=[zero, one], window=[zero, one])
  110. x = 2 * p + p ** 2
  111. assert_equal(x.coef, np.array([Fraction(16, 9), Fraction(20, 9),
  112. Fraction(4, 9)], dtype=object))
  113. assert_equal(p.domain, [zero, one])
  114. assert_equal(p.coef.dtype, np.dtypes.ObjectDType())
  115. assert_(isinstance(p(f), Fraction))
  116. assert_equal(p(f), Fraction(10, 9))
  117. p_deriv = poly.Polynomial([Fraction(2, 3)], domain=[zero, one],
  118. window=[zero, one])
  119. assert_equal(p.deriv(), p_deriv)
  120. class TestEvaluation:
  121. # coefficients of 1 + 2*x + 3*x**2
  122. c1d = np.array([1., 2., 3.])
  123. c2d = np.einsum('i,j->ij', c1d, c1d)
  124. c3d = np.einsum('i,j,k->ijk', c1d, c1d, c1d)
  125. # some random values in [-1, 1)
  126. x = np.random.random((3, 5))*2 - 1
  127. y = poly.polyval(x, [1., 2., 3.])
  128. def test_polyval(self):
  129. #check empty input
  130. assert_equal(poly.polyval([], [1]).size, 0)
  131. #check normal input)
  132. x = np.linspace(-1, 1)
  133. y = [x**i for i in range(5)]
  134. for i in range(5):
  135. tgt = y[i]
  136. res = poly.polyval(x, [0]*i + [1])
  137. assert_almost_equal(res, tgt)
  138. tgt = x*(x**2 - 1)
  139. res = poly.polyval(x, [0, -1, 0, 1])
  140. assert_almost_equal(res, tgt)
  141. #check that shape is preserved
  142. for i in range(3):
  143. dims = [2]*i
  144. x = np.zeros(dims)
  145. assert_equal(poly.polyval(x, [1]).shape, dims)
  146. assert_equal(poly.polyval(x, [1, 0]).shape, dims)
  147. assert_equal(poly.polyval(x, [1, 0, 0]).shape, dims)
  148. #check masked arrays are processed correctly
  149. mask = [False, True, False]
  150. mx = np.ma.array([1, 2, 3], mask=mask)
  151. res = np.polyval([7, 5, 3], mx)
  152. assert_array_equal(res.mask, mask)
  153. #check subtypes of ndarray are preserved
  154. class C(np.ndarray):
  155. pass
  156. cx = np.array([1, 2, 3]).view(C)
  157. assert_equal(type(np.polyval([2, 3, 4], cx)), C)
  158. def test_polyvalfromroots(self):
  159. # check exception for broadcasting x values over root array with
  160. # too few dimensions
  161. assert_raises(ValueError, poly.polyvalfromroots,
  162. [1], [1], tensor=False)
  163. # check empty input
  164. assert_equal(poly.polyvalfromroots([], [1]).size, 0)
  165. assert_(poly.polyvalfromroots([], [1]).shape == (0,))
  166. # check empty input + multidimensional roots
  167. assert_equal(poly.polyvalfromroots([], [[1] * 5]).size, 0)
  168. assert_(poly.polyvalfromroots([], [[1] * 5]).shape == (5, 0))
  169. # check scalar input
  170. assert_equal(poly.polyvalfromroots(1, 1), 0)
  171. assert_(poly.polyvalfromroots(1, np.ones((3, 3))).shape == (3,))
  172. # check normal input)
  173. x = np.linspace(-1, 1)
  174. y = [x**i for i in range(5)]
  175. for i in range(1, 5):
  176. tgt = y[i]
  177. res = poly.polyvalfromroots(x, [0]*i)
  178. assert_almost_equal(res, tgt)
  179. tgt = x*(x - 1)*(x + 1)
  180. res = poly.polyvalfromroots(x, [-1, 0, 1])
  181. assert_almost_equal(res, tgt)
  182. # check that shape is preserved
  183. for i in range(3):
  184. dims = [2]*i
  185. x = np.zeros(dims)
  186. assert_equal(poly.polyvalfromroots(x, [1]).shape, dims)
  187. assert_equal(poly.polyvalfromroots(x, [1, 0]).shape, dims)
  188. assert_equal(poly.polyvalfromroots(x, [1, 0, 0]).shape, dims)
  189. # check compatibility with factorization
  190. ptest = [15, 2, -16, -2, 1]
  191. r = poly.polyroots(ptest)
  192. x = np.linspace(-1, 1)
  193. assert_almost_equal(poly.polyval(x, ptest),
  194. poly.polyvalfromroots(x, r))
  195. # check multidimensional arrays of roots and values
  196. # check tensor=False
  197. rshape = (3, 5)
  198. x = np.arange(-3, 2)
  199. r = np.random.randint(-5, 5, size=rshape)
  200. res = poly.polyvalfromroots(x, r, tensor=False)
  201. tgt = np.empty(r.shape[1:])
  202. for ii in range(tgt.size):
  203. tgt[ii] = poly.polyvalfromroots(x[ii], r[:, ii])
  204. assert_equal(res, tgt)
  205. # check tensor=True
  206. x = np.vstack([x, 2*x])
  207. res = poly.polyvalfromroots(x, r, tensor=True)
  208. tgt = np.empty(r.shape[1:] + x.shape)
  209. for ii in range(r.shape[1]):
  210. for jj in range(x.shape[0]):
  211. tgt[ii, jj, :] = poly.polyvalfromroots(x[jj], r[:, ii])
  212. assert_equal(res, tgt)
  213. def test_polyval2d(self):
  214. x1, x2, x3 = self.x
  215. y1, y2, y3 = self.y
  216. #test exceptions
  217. assert_raises_regex(ValueError, 'incompatible',
  218. poly.polyval2d, x1, x2[:2], self.c2d)
  219. #test values
  220. tgt = y1*y2
  221. res = poly.polyval2d(x1, x2, self.c2d)
  222. assert_almost_equal(res, tgt)
  223. #test shape
  224. z = np.ones((2, 3))
  225. res = poly.polyval2d(z, z, self.c2d)
  226. assert_(res.shape == (2, 3))
  227. def test_polyval3d(self):
  228. x1, x2, x3 = self.x
  229. y1, y2, y3 = self.y
  230. #test exceptions
  231. assert_raises_regex(ValueError, 'incompatible',
  232. poly.polyval3d, x1, x2, x3[:2], self.c3d)
  233. #test values
  234. tgt = y1*y2*y3
  235. res = poly.polyval3d(x1, x2, x3, self.c3d)
  236. assert_almost_equal(res, tgt)
  237. #test shape
  238. z = np.ones((2, 3))
  239. res = poly.polyval3d(z, z, z, self.c3d)
  240. assert_(res.shape == (2, 3))
  241. def test_polygrid2d(self):
  242. x1, x2, x3 = self.x
  243. y1, y2, y3 = self.y
  244. #test values
  245. tgt = np.einsum('i,j->ij', y1, y2)
  246. res = poly.polygrid2d(x1, x2, self.c2d)
  247. assert_almost_equal(res, tgt)
  248. #test shape
  249. z = np.ones((2, 3))
  250. res = poly.polygrid2d(z, z, self.c2d)
  251. assert_(res.shape == (2, 3)*2)
  252. def test_polygrid3d(self):
  253. x1, x2, x3 = self.x
  254. y1, y2, y3 = self.y
  255. #test values
  256. tgt = np.einsum('i,j,k->ijk', y1, y2, y3)
  257. res = poly.polygrid3d(x1, x2, x3, self.c3d)
  258. assert_almost_equal(res, tgt)
  259. #test shape
  260. z = np.ones((2, 3))
  261. res = poly.polygrid3d(z, z, z, self.c3d)
  262. assert_(res.shape == (2, 3)*3)
  263. class TestIntegral:
  264. def test_polyint(self):
  265. # check exceptions
  266. assert_raises(TypeError, poly.polyint, [0], .5)
  267. assert_raises(ValueError, poly.polyint, [0], -1)
  268. assert_raises(ValueError, poly.polyint, [0], 1, [0, 0])
  269. assert_raises(ValueError, poly.polyint, [0], lbnd=[0])
  270. assert_raises(ValueError, poly.polyint, [0], scl=[0])
  271. assert_raises(TypeError, poly.polyint, [0], axis=.5)
  272. assert_raises(TypeError, poly.polyint, [1, 1], 1.)
  273. # test integration of zero polynomial
  274. for i in range(2, 5):
  275. k = [0]*(i - 2) + [1]
  276. res = poly.polyint([0], m=i, k=k)
  277. assert_almost_equal(res, [0, 1])
  278. # check single integration with integration constant
  279. for i in range(5):
  280. scl = i + 1
  281. pol = [0]*i + [1]
  282. tgt = [i] + [0]*i + [1/scl]
  283. res = poly.polyint(pol, m=1, k=[i])
  284. assert_almost_equal(trim(res), trim(tgt))
  285. # check single integration with integration constant and lbnd
  286. for i in range(5):
  287. scl = i + 1
  288. pol = [0]*i + [1]
  289. res = poly.polyint(pol, m=1, k=[i], lbnd=-1)
  290. assert_almost_equal(poly.polyval(-1, res), i)
  291. # check single integration with integration constant and scaling
  292. for i in range(5):
  293. scl = i + 1
  294. pol = [0]*i + [1]
  295. tgt = [i] + [0]*i + [2/scl]
  296. res = poly.polyint(pol, m=1, k=[i], scl=2)
  297. assert_almost_equal(trim(res), trim(tgt))
  298. # check multiple integrations with default k
  299. for i in range(5):
  300. for j in range(2, 5):
  301. pol = [0]*i + [1]
  302. tgt = pol[:]
  303. for k in range(j):
  304. tgt = poly.polyint(tgt, m=1)
  305. res = poly.polyint(pol, m=j)
  306. assert_almost_equal(trim(res), trim(tgt))
  307. # check multiple integrations with defined k
  308. for i in range(5):
  309. for j in range(2, 5):
  310. pol = [0]*i + [1]
  311. tgt = pol[:]
  312. for k in range(j):
  313. tgt = poly.polyint(tgt, m=1, k=[k])
  314. res = poly.polyint(pol, m=j, k=list(range(j)))
  315. assert_almost_equal(trim(res), trim(tgt))
  316. # check multiple integrations with lbnd
  317. for i in range(5):
  318. for j in range(2, 5):
  319. pol = [0]*i + [1]
  320. tgt = pol[:]
  321. for k in range(j):
  322. tgt = poly.polyint(tgt, m=1, k=[k], lbnd=-1)
  323. res = poly.polyint(pol, m=j, k=list(range(j)), lbnd=-1)
  324. assert_almost_equal(trim(res), trim(tgt))
  325. # check multiple integrations with scaling
  326. for i in range(5):
  327. for j in range(2, 5):
  328. pol = [0]*i + [1]
  329. tgt = pol[:]
  330. for k in range(j):
  331. tgt = poly.polyint(tgt, m=1, k=[k], scl=2)
  332. res = poly.polyint(pol, m=j, k=list(range(j)), scl=2)
  333. assert_almost_equal(trim(res), trim(tgt))
  334. def test_polyint_axis(self):
  335. # check that axis keyword works
  336. c2d = np.random.random((3, 4))
  337. tgt = np.vstack([poly.polyint(c) for c in c2d.T]).T
  338. res = poly.polyint(c2d, axis=0)
  339. assert_almost_equal(res, tgt)
  340. tgt = np.vstack([poly.polyint(c) for c in c2d])
  341. res = poly.polyint(c2d, axis=1)
  342. assert_almost_equal(res, tgt)
  343. tgt = np.vstack([poly.polyint(c, k=3) for c in c2d])
  344. res = poly.polyint(c2d, k=3, axis=1)
  345. assert_almost_equal(res, tgt)
  346. class TestDerivative:
  347. def test_polyder(self):
  348. # check exceptions
  349. assert_raises(TypeError, poly.polyder, [0], .5)
  350. assert_raises(ValueError, poly.polyder, [0], -1)
  351. # check that zeroth derivative does nothing
  352. for i in range(5):
  353. tgt = [0]*i + [1]
  354. res = poly.polyder(tgt, m=0)
  355. assert_equal(trim(res), trim(tgt))
  356. # check that derivation is the inverse of integration
  357. for i in range(5):
  358. for j in range(2, 5):
  359. tgt = [0]*i + [1]
  360. res = poly.polyder(poly.polyint(tgt, m=j), m=j)
  361. assert_almost_equal(trim(res), trim(tgt))
  362. # check derivation with scaling
  363. for i in range(5):
  364. for j in range(2, 5):
  365. tgt = [0]*i + [1]
  366. res = poly.polyder(poly.polyint(tgt, m=j, scl=2), m=j, scl=.5)
  367. assert_almost_equal(trim(res), trim(tgt))
  368. def test_polyder_axis(self):
  369. # check that axis keyword works
  370. c2d = np.random.random((3, 4))
  371. tgt = np.vstack([poly.polyder(c) for c in c2d.T]).T
  372. res = poly.polyder(c2d, axis=0)
  373. assert_almost_equal(res, tgt)
  374. tgt = np.vstack([poly.polyder(c) for c in c2d])
  375. res = poly.polyder(c2d, axis=1)
  376. assert_almost_equal(res, tgt)
  377. class TestVander:
  378. # some random values in [-1, 1)
  379. x = np.random.random((3, 5))*2 - 1
  380. def test_polyvander(self):
  381. # check for 1d x
  382. x = np.arange(3)
  383. v = poly.polyvander(x, 3)
  384. assert_(v.shape == (3, 4))
  385. for i in range(4):
  386. coef = [0]*i + [1]
  387. assert_almost_equal(v[..., i], poly.polyval(x, coef))
  388. # check for 2d x
  389. x = np.array([[1, 2], [3, 4], [5, 6]])
  390. v = poly.polyvander(x, 3)
  391. assert_(v.shape == (3, 2, 4))
  392. for i in range(4):
  393. coef = [0]*i + [1]
  394. assert_almost_equal(v[..., i], poly.polyval(x, coef))
  395. def test_polyvander2d(self):
  396. # also tests polyval2d for non-square coefficient array
  397. x1, x2, x3 = self.x
  398. c = np.random.random((2, 3))
  399. van = poly.polyvander2d(x1, x2, [1, 2])
  400. tgt = poly.polyval2d(x1, x2, c)
  401. res = np.dot(van, c.flat)
  402. assert_almost_equal(res, tgt)
  403. # check shape
  404. van = poly.polyvander2d([x1], [x2], [1, 2])
  405. assert_(van.shape == (1, 5, 6))
  406. def test_polyvander3d(self):
  407. # also tests polyval3d for non-square coefficient array
  408. x1, x2, x3 = self.x
  409. c = np.random.random((2, 3, 4))
  410. van = poly.polyvander3d(x1, x2, x3, [1, 2, 3])
  411. tgt = poly.polyval3d(x1, x2, x3, c)
  412. res = np.dot(van, c.flat)
  413. assert_almost_equal(res, tgt)
  414. # check shape
  415. van = poly.polyvander3d([x1], [x2], [x3], [1, 2, 3])
  416. assert_(van.shape == (1, 5, 24))
  417. def test_polyvandernegdeg(self):
  418. x = np.arange(3)
  419. assert_raises(ValueError, poly.polyvander, x, -1)
  420. class TestCompanion:
  421. def test_raises(self):
  422. assert_raises(ValueError, poly.polycompanion, [])
  423. assert_raises(ValueError, poly.polycompanion, [1])
  424. def test_dimensions(self):
  425. for i in range(1, 5):
  426. coef = [0]*i + [1]
  427. assert_(poly.polycompanion(coef).shape == (i, i))
  428. def test_linear_root(self):
  429. assert_(poly.polycompanion([1, 2])[0, 0] == -.5)
  430. class TestMisc:
  431. def test_polyfromroots(self):
  432. res = poly.polyfromroots([])
  433. assert_almost_equal(trim(res), [1])
  434. for i in range(1, 5):
  435. roots = np.cos(np.linspace(-np.pi, 0, 2*i + 1)[1::2])
  436. tgt = Tlist[i]
  437. res = poly.polyfromroots(roots)*2**(i-1)
  438. assert_almost_equal(trim(res), trim(tgt))
  439. def test_polyroots(self):
  440. assert_almost_equal(poly.polyroots([1]), [])
  441. assert_almost_equal(poly.polyroots([1, 2]), [-.5])
  442. for i in range(2, 5):
  443. tgt = np.linspace(-1, 1, i)
  444. res = poly.polyroots(poly.polyfromroots(tgt))
  445. assert_almost_equal(trim(res), trim(tgt))
  446. def test_polyfit(self):
  447. def f(x):
  448. return x*(x - 1)*(x - 2)
  449. def f2(x):
  450. return x**4 + x**2 + 1
  451. # Test exceptions
  452. assert_raises(ValueError, poly.polyfit, [1], [1], -1)
  453. assert_raises(TypeError, poly.polyfit, [[1]], [1], 0)
  454. assert_raises(TypeError, poly.polyfit, [], [1], 0)
  455. assert_raises(TypeError, poly.polyfit, [1], [[[1]]], 0)
  456. assert_raises(TypeError, poly.polyfit, [1, 2], [1], 0)
  457. assert_raises(TypeError, poly.polyfit, [1], [1, 2], 0)
  458. assert_raises(TypeError, poly.polyfit, [1], [1], 0, w=[[1]])
  459. assert_raises(TypeError, poly.polyfit, [1], [1], 0, w=[1, 1])
  460. assert_raises(ValueError, poly.polyfit, [1], [1], [-1,])
  461. assert_raises(ValueError, poly.polyfit, [1], [1], [2, -1, 6])
  462. assert_raises(TypeError, poly.polyfit, [1], [1], [])
  463. # Test fit
  464. x = np.linspace(0, 2)
  465. y = f(x)
  466. #
  467. coef3 = poly.polyfit(x, y, 3)
  468. assert_equal(len(coef3), 4)
  469. assert_almost_equal(poly.polyval(x, coef3), y)
  470. coef3 = poly.polyfit(x, y, [0, 1, 2, 3])
  471. assert_equal(len(coef3), 4)
  472. assert_almost_equal(poly.polyval(x, coef3), y)
  473. #
  474. coef4 = poly.polyfit(x, y, 4)
  475. assert_equal(len(coef4), 5)
  476. assert_almost_equal(poly.polyval(x, coef4), y)
  477. coef4 = poly.polyfit(x, y, [0, 1, 2, 3, 4])
  478. assert_equal(len(coef4), 5)
  479. assert_almost_equal(poly.polyval(x, coef4), y)
  480. #
  481. coef2d = poly.polyfit(x, np.array([y, y]).T, 3)
  482. assert_almost_equal(coef2d, np.array([coef3, coef3]).T)
  483. coef2d = poly.polyfit(x, np.array([y, y]).T, [0, 1, 2, 3])
  484. assert_almost_equal(coef2d, np.array([coef3, coef3]).T)
  485. # test weighting
  486. w = np.zeros_like(x)
  487. yw = y.copy()
  488. w[1::2] = 1
  489. yw[0::2] = 0
  490. wcoef3 = poly.polyfit(x, yw, 3, w=w)
  491. assert_almost_equal(wcoef3, coef3)
  492. wcoef3 = poly.polyfit(x, yw, [0, 1, 2, 3], w=w)
  493. assert_almost_equal(wcoef3, coef3)
  494. #
  495. wcoef2d = poly.polyfit(x, np.array([yw, yw]).T, 3, w=w)
  496. assert_almost_equal(wcoef2d, np.array([coef3, coef3]).T)
  497. wcoef2d = poly.polyfit(x, np.array([yw, yw]).T, [0, 1, 2, 3], w=w)
  498. assert_almost_equal(wcoef2d, np.array([coef3, coef3]).T)
  499. # test scaling with complex values x points whose square
  500. # is zero when summed.
  501. x = [1, 1j, -1, -1j]
  502. assert_almost_equal(poly.polyfit(x, x, 1), [0, 1])
  503. assert_almost_equal(poly.polyfit(x, x, [0, 1]), [0, 1])
  504. # test fitting only even Polyendre polynomials
  505. x = np.linspace(-1, 1)
  506. y = f2(x)
  507. coef1 = poly.polyfit(x, y, 4)
  508. assert_almost_equal(poly.polyval(x, coef1), y)
  509. coef2 = poly.polyfit(x, y, [0, 2, 4])
  510. assert_almost_equal(poly.polyval(x, coef2), y)
  511. assert_almost_equal(coef1, coef2)
  512. def test_polytrim(self):
  513. coef = [2, -1, 1, 0]
  514. # Test exceptions
  515. assert_raises(ValueError, poly.polytrim, coef, -1)
  516. # Test results
  517. assert_equal(poly.polytrim(coef), coef[:-1])
  518. assert_equal(poly.polytrim(coef, 1), coef[:-3])
  519. assert_equal(poly.polytrim(coef, 2), [0])
  520. def test_polyline(self):
  521. assert_equal(poly.polyline(3, 4), [3, 4])
  522. def test_polyline_zero(self):
  523. assert_equal(poly.polyline(3, 0), [3])
  524. def test_fit_degenerate_domain(self):
  525. p = poly.Polynomial.fit([1], [2], deg=0)
  526. assert_equal(p.coef, [2.])
  527. p = poly.Polynomial.fit([1, 1], [2, 2.1], deg=0)
  528. assert_almost_equal(p.coef, [2.05])
  529. with assert_warns(pu.RankWarning):
  530. p = poly.Polynomial.fit([1, 1], [2, 2.1], deg=1)
  531. def test_result_type(self):
  532. w = np.array([-1, 1], dtype=np.float32)
  533. p = np.polynomial.Polynomial(w, domain=w, window=w)
  534. v = p(2)
  535. assert_equal(v.dtype, np.float32)
  536. arr = np.polydiv(1, np.float32(1))
  537. assert_equal(arr[0].dtype, np.float64)