test_polyutils.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. """Tests for polyutils module.
  2. """
  3. import numpy as np
  4. import numpy.polynomial.polyutils as pu
  5. from numpy.testing import (
  6. assert_almost_equal, assert_raises, assert_equal, assert_,
  7. )
  8. class TestMisc:
  9. def test_trimseq(self):
  10. tgt = [1]
  11. for num_trailing_zeros in range(5):
  12. res = pu.trimseq([1] + [0] * num_trailing_zeros)
  13. assert_equal(res, tgt)
  14. def test_trimseq_empty_input(self):
  15. for empty_seq in [[], np.array([], dtype=np.int32)]:
  16. assert_equal(pu.trimseq(empty_seq), empty_seq)
  17. def test_as_series(self):
  18. # check exceptions
  19. assert_raises(ValueError, pu.as_series, [[]])
  20. assert_raises(ValueError, pu.as_series, [[[1, 2]]])
  21. assert_raises(ValueError, pu.as_series, [[1], ['a']])
  22. # check common types
  23. types = ['i', 'd', 'O']
  24. for i in range(len(types)):
  25. for j in range(i):
  26. ci = np.ones(1, types[i])
  27. cj = np.ones(1, types[j])
  28. [resi, resj] = pu.as_series([ci, cj])
  29. assert_(resi.dtype.char == resj.dtype.char)
  30. assert_(resj.dtype.char == types[i])
  31. def test_trimcoef(self):
  32. coef = [2, -1, 1, 0]
  33. # Test exceptions
  34. assert_raises(ValueError, pu.trimcoef, coef, -1)
  35. # Test results
  36. assert_equal(pu.trimcoef(coef), coef[:-1])
  37. assert_equal(pu.trimcoef(coef, 1), coef[:-3])
  38. assert_equal(pu.trimcoef(coef, 2), [0])
  39. def test_vander_nd_exception(self):
  40. # n_dims != len(points)
  41. assert_raises(ValueError, pu._vander_nd, (), (1, 2, 3), [90])
  42. # n_dims != len(degrees)
  43. assert_raises(ValueError, pu._vander_nd, (), (), [90.65])
  44. # n_dims == 0
  45. assert_raises(ValueError, pu._vander_nd, (), (), [])
  46. def test_div_zerodiv(self):
  47. # c2[-1] == 0
  48. assert_raises(ZeroDivisionError, pu._div, pu._div, (1, 2, 3), [0])
  49. def test_pow_too_large(self):
  50. # power > maxpower
  51. assert_raises(ValueError, pu._pow, (), [1, 2, 3], 5, 4)
  52. class TestDomain:
  53. def test_getdomain(self):
  54. # test for real values
  55. x = [1, 10, 3, -1]
  56. tgt = [-1, 10]
  57. res = pu.getdomain(x)
  58. assert_almost_equal(res, tgt)
  59. # test for complex values
  60. x = [1 + 1j, 1 - 1j, 0, 2]
  61. tgt = [-1j, 2 + 1j]
  62. res = pu.getdomain(x)
  63. assert_almost_equal(res, tgt)
  64. def test_mapdomain(self):
  65. # test for real values
  66. dom1 = [0, 4]
  67. dom2 = [1, 3]
  68. tgt = dom2
  69. res = pu.mapdomain(dom1, dom1, dom2)
  70. assert_almost_equal(res, tgt)
  71. # test for complex values
  72. dom1 = [0 - 1j, 2 + 1j]
  73. dom2 = [-2, 2]
  74. tgt = dom2
  75. x = dom1
  76. res = pu.mapdomain(x, dom1, dom2)
  77. assert_almost_equal(res, tgt)
  78. # test for multidimensional arrays
  79. dom1 = [0, 4]
  80. dom2 = [1, 3]
  81. tgt = np.array([dom2, dom2])
  82. x = np.array([dom1, dom1])
  83. res = pu.mapdomain(x, dom1, dom2)
  84. assert_almost_equal(res, tgt)
  85. # test that subtypes are preserved.
  86. class MyNDArray(np.ndarray):
  87. pass
  88. dom1 = [0, 4]
  89. dom2 = [1, 3]
  90. x = np.array([dom1, dom1]).view(MyNDArray)
  91. res = pu.mapdomain(x, dom1, dom2)
  92. assert_(isinstance(res, MyNDArray))
  93. def test_mapparms(self):
  94. # test for real values
  95. dom1 = [0, 4]
  96. dom2 = [1, 3]
  97. tgt = [1, .5]
  98. res = pu. mapparms(dom1, dom2)
  99. assert_almost_equal(res, tgt)
  100. # test for complex values
  101. dom1 = [0 - 1j, 2 + 1j]
  102. dom2 = [-2, 2]
  103. tgt = [-1 + 1j, 1 - 1j]
  104. res = pu.mapparms(dom1, dom2)
  105. assert_almost_equal(res, tgt)