test_linsolve.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. #
  2. # test_linsolve.py
  3. #
  4. # Test the internal implementation of linsolve.
  5. #
  6. from sympy.testing.pytest import raises
  7. from sympy.core.numbers import I
  8. from sympy.core.relational import Eq
  9. from sympy.core.singleton import S
  10. from sympy.abc import x, y, z
  11. from sympy.polys.matrices.linsolve import _linsolve
  12. from sympy.polys.solvers import PolyNonlinearError
  def test__linsolve():
      """Exercise _linsolve on a handful of small, exactly-specified systems.

      Covers an empty system, a system whose only equation is 0, an
      inconsistent system, uniquely solvable systems, a nonzero constant
      equation, and an input with a quadratic term that must be rejected.
      """
      # Nothing (or only 0 = 0) constrains x, so x maps to itself.
      for unconstrained in ([], [S.Zero]):
          assert _linsolve(unconstrained, [x]) == {x: x}

      # x = 1 and x = 2 cannot both hold: no solution.
      assert _linsolve([x - 1, x - 2], [x]) is None

      # Uniquely determined systems.
      assert _linsolve([x - 1], [x]) == {x: 1}
      assert _linsolve([x - 1, y], [x, y]) == {x: 1, y: S.Zero}

      # The equation 2*I = 0 is never satisfied: no solution.
      assert _linsolve([2*I], [x]) is None

      # x*(1 + x) contains x**2, so the input is not linear in x.
      raises(PolyNonlinearError, lambda: _linsolve([x*(1 + x)], [x]))
  def test__linsolve_float():
      """Exercise _linsolve on linear systems with floating-point coefficients.

      The first system is homogeneous and compared exactly. The remaining
      systems (real and complex coefficients) are compared against their
      exact solutions with an absolute tolerance.
      """
      # This should give the exact answer:
      eqs = [
          y - x,
          y - 0.0216 * x
      ]
      # Should _linsolve return floats here?
      sol = {x: 0, y: 0}
      assert _linsolve(eqs, (x, y)) == sol

      # Other cases should be close to eps
      def all_close(sol1, sol2, eps=1e-15):
          """Return True if sol1 and sol2 agree within eps on every symbol.

          Both arguments must be solution dicts with identical keys; a
          missing solution (None) or a key mismatch fails via assert.
          """
          # _linsolve returns None for an inconsistent system; fail with an
          # assertion here instead of an AttributeError on .keys() below.
          assert sol1 is not None and sol2 is not None
          assert sol1.keys() == sol2.keys()
          return all(abs(sol1[s] - sol2[s]) < eps for s in sol1)

      # Each case pairs a float-coefficient system in x, y, z with its exact
      # solution. The rationals use Python `/`, so they evaluate to floats;
      # the complex case combines those floats with I.
      cases = [
          (
              [
                  0.8*x + 0.8*z + 0.2,
                  0.9*x + 0.7*y + 0.2*z + 0.9,
                  0.7*x + 0.2*y + 0.2*z + 0.5,
              ],
              {x: -29/42, y: -11/21, z: 37/84},
          ),
          (
              [
                  0.9*x + 0.3*y + 0.4*z + 0.6,
                  0.6*x + 0.9*y + 0.1*z + 0.7,
                  0.4*x + 0.6*y + 0.9*z + 0.5,
              ],
              {x: -88/175, y: -46/105, z: -1/25},
          ),
          (
              [
                  0.4*x + 0.3*y + 0.6*z + 0.7,
                  0.4*x + 0.3*y + 0.9*z + 0.9,
                  0.7*x + 0.9*y,
              ],
              {x: -9/5, y: 7/5, z: -2/3},
          ),
          (
              [
                  x*(0.7 + 0.6*I) + y*(0.4 + 0.7*I) + z*(0.9 + 0.1*I) + 0.5,
                  0.2*I*x + 0.2*I*y + z*(0.9 + 0.2*I) + 0.1,
                  x*(0.9 + 0.7*I) + y*(0.9 + 0.7*I) + z*(0.9 + 0.4*I) + 0.4,
              ],
              {
                  x: -6157/7995 - 411/5330*I,
                  y: 8519/15990 + 1784/7995*I,
                  z: -34/533 + 107/1599*I,
              },
          ),
      ]
      for eqs, sol_exact in cases:
          sol_linsolve = _linsolve(eqs, [x, y, z])
          # On failure, show which system failed and what _linsolve returned.
          assert all_close(sol_exact, sol_linsolve), (eqs, sol_linsolve)

      # XXX: This system for x and y over RR(z) is problematic.
      #
      # eqs = [
      #     x*(0.2*z + 0.9) + y*(0.5*z + 0.8) + 0.6,
      #     0.1*x*z + y*(0.1*z + 0.6) + 0.9,
      # ]
      #
      # linsolve(eqs, [x, y])
      # The solution for x comes out as
      #
      #       -3.9e-5*z**2 - 3.6e-5*z - 8.67361737988404e-20
      #  x =  ----------------------------------------------
      #           3.0e-6*z**3 - 1.3e-5*z**2 - 5.4e-5*z
      #
      # The 8e-20 in the numerator should be zero which would allow z to cancel
      # from top and bottom. It should be possible to avoid this somehow because
      # the inverse of the matrix only has a quadratic factor (the determinant)
      # in the denominator.
  def test__linsolve_deprecated():
      """Check that _linsolve raises PolyNonlinearError for these inputs.

      Each input contains a squared term (x**2 or (x + y)**2) as written,
      even though the x**2 parts cancel after expansion. _linsolve is
      expected to reject all of them rather than simplify first.
      NOTE(review): the test name suggests these forms were once accepted
      and that support was deprecated; confirm against the changelog.
      """
      # Build each input inside the callable so that construction happens
      # under raises(), just as with an inline lambda.
      attempts = (
          lambda: _linsolve([Eq(x**2, x**2 + y)], [x, y]),
          lambda: _linsolve([(x + y)**2 - x**2], [x]),
          lambda: _linsolve([Eq((x + y)**2, x**2)], [x]),
      )
      for attempt in attempts:
          raises(PolyNonlinearError, attempt)