test_polysys.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. """Tests for solvers of systems of polynomial equations. """
  2. from sympy.polys.domains import ZZ, QQ_I
  3. from sympy.core.numbers import (I, Integer, Rational)
  4. from sympy.core.singleton import S
  5. from sympy.core.symbol import symbols
  6. from sympy.functions.elementary.miscellaneous import sqrt
  7. from sympy.polys.domains.rationalfield import QQ
  8. from sympy.polys.polyerrors import UnsolvableFactorError
  9. from sympy.polys.polyoptions import Options
  10. from sympy.polys.polytools import Poly
  11. from sympy.polys.rootoftools import CRootOf
  12. from sympy.solvers.solvers import solve
  13. from sympy.utilities.iterables import flatten
  14. from sympy.abc import a, b, c, x, y, z
  15. from sympy.polys import PolynomialError
  16. from sympy.solvers.polysys import (solve_poly_system,
  17. solve_triangulated,
  18. solve_biquadratic, SolveFailed,
  19. solve_generic, factor_system_bool,
  20. factor_system_cond, factor_system_poly,
  21. factor_system, _factor_sets, _factor_sets_slow)
  22. from sympy.polys.polytools import parallel_poly_from_expr
  23. from sympy.testing.pytest import raises
  24. from sympy.core.relational import Eq
  25. from sympy.functions.elementary.trigonometric import sin, cos
  26. from sympy.functions.elementary.exponential import exp
  27. def test_solve_poly_system():
  28. assert solve_poly_system([x - 1], x) == [(S.One,)]
  29. assert solve_poly_system([y - x, y - x - 1], x, y) is None
  30. assert solve_poly_system([y - x**2, y + x**2], x, y) == [(S.Zero, S.Zero)]
  31. assert solve_poly_system([2*x - 3, y*Rational(3, 2) - 2*x, z - 5*y], x, y, z) == \
  32. [(Rational(3, 2), Integer(2), Integer(10))]
  33. assert solve_poly_system([x*y - 2*y, 2*y**2 - x**2], x, y) == \
  34. [(0, 0), (2, -sqrt(2)), (2, sqrt(2))]
  35. assert solve_poly_system([y - x**2, y + x**2 + 1], x, y) == \
  36. [(-I*sqrt(S.Half), Rational(-1, 2)), (I*sqrt(S.Half), Rational(-1, 2))]
  37. f_1 = x**2 + y + z - 1
  38. f_2 = x + y**2 + z - 1
  39. f_3 = x + y + z**2 - 1
  40. a, b = sqrt(2) - 1, -sqrt(2) - 1
  41. assert solve_poly_system([f_1, f_2, f_3], x, y, z) == \
  42. [(0, 0, 1), (0, 1, 0), (1, 0, 0), (a, a, a), (b, b, b)]
  43. solution = [(1, -1), (1, 1)]
  44. assert solve_poly_system([Poly(x**2 - y**2), Poly(x - 1)]) == solution
  45. assert solve_poly_system([x**2 - y**2, x - 1], x, y) == solution
  46. assert solve_poly_system([x**2 - y**2, x - 1]) == solution
  47. assert solve_poly_system(
  48. [x + x*y - 3, y + x*y - 4], x, y) == [(-3, -2), (1, 2)]
  49. raises(NotImplementedError, lambda: solve_poly_system([x**3 - y**3], x, y))
  50. raises(NotImplementedError, lambda: solve_poly_system(
  51. [z, -2*x*y**2 + x + y**2*z, y**2*(-z - 4) + 2]))
  52. raises(PolynomialError, lambda: solve_poly_system([1/x], x))
  53. raises(NotImplementedError, lambda: solve_poly_system(
  54. [x-1,], (x, y)))
  55. raises(NotImplementedError, lambda: solve_poly_system(
  56. [y-1,], (x, y)))
  57. # solve_poly_system should ideally construct solutions using
  58. # CRootOf for the following four tests
  59. assert solve_poly_system([x**5 - x + 1], [x], strict=False) == []
  60. raises(UnsolvableFactorError, lambda: solve_poly_system(
  61. [x**5 - x + 1], [x], strict=True))
  62. assert solve_poly_system([(x - 1)*(x**5 - x + 1), y**2 - 1], [x, y],
  63. strict=False) == [(1, -1), (1, 1)]
  64. raises(UnsolvableFactorError,
  65. lambda: solve_poly_system([(x - 1)*(x**5 - x + 1), y**2-1],
  66. [x, y], strict=True))
  67. def test_solve_generic():
  68. NewOption = Options((x, y), {'domain': 'ZZ'})
  69. assert solve_generic([x**2 - 2*y**2, y**2 - y + 1], NewOption) == \
  70. [(-sqrt(-1 - sqrt(3)*I), Rational(1, 2) - sqrt(3)*I/2),
  71. (sqrt(-1 - sqrt(3)*I), Rational(1, 2) - sqrt(3)*I/2),
  72. (-sqrt(-1 + sqrt(3)*I), Rational(1, 2) + sqrt(3)*I/2),
  73. (sqrt(-1 + sqrt(3)*I), Rational(1, 2) + sqrt(3)*I/2)]
  74. # solve_generic should ideally construct solutions using
  75. # CRootOf for the following two tests
  76. assert solve_generic(
  77. [2*x - y, (y - 1)*(y**5 - y + 1)], NewOption, strict=False) == \
  78. [(Rational(1, 2), 1)]
  79. raises(UnsolvableFactorError, lambda: solve_generic(
  80. [2*x - y, (y - 1)*(y**5 - y + 1)], NewOption, strict=True))
  81. def test_solve_biquadratic():
  82. x0, y0, x1, y1, r = symbols('x0 y0 x1 y1 r')
  83. f_1 = (x - 1)**2 + (y - 1)**2 - r**2
  84. f_2 = (x - 2)**2 + (y - 2)**2 - r**2
  85. s = sqrt(2*r**2 - 1)
  86. a = (3 - s)/2
  87. b = (3 + s)/2
  88. assert solve_poly_system([f_1, f_2], x, y) == [(a, b), (b, a)]
  89. f_1 = (x - 1)**2 + (y - 2)**2 - r**2
  90. f_2 = (x - 1)**2 + (y - 1)**2 - r**2
  91. assert solve_poly_system([f_1, f_2], x, y) == \
  92. [(1 - sqrt((2*r - 1)*(2*r + 1))/2, Rational(3, 2)),
  93. (1 + sqrt((2*r - 1)*(2*r + 1))/2, Rational(3, 2))]
  94. query = lambda expr: expr.is_Pow and expr.exp is S.Half
  95. f_1 = (x - 1 )**2 + (y - 2)**2 - r**2
  96. f_2 = (x - x1)**2 + (y - 1)**2 - r**2
  97. result = solve_poly_system([f_1, f_2], x, y)
  98. assert len(result) == 2 and all(len(r) == 2 for r in result)
  99. assert all(r.count(query) == 1 for r in flatten(result))
  100. f_1 = (x - x0)**2 + (y - y0)**2 - r**2
  101. f_2 = (x - x1)**2 + (y - y1)**2 - r**2
  102. result = solve_poly_system([f_1, f_2], x, y)
  103. assert len(result) == 2 and all(len(r) == 2 for r in result)
  104. assert all(len(r.find(query)) == 1 for r in flatten(result))
  105. s1 = (x*y - y, x**2 - x)
  106. assert solve(s1) == [{x: 1}, {x: 0, y: 0}]
  107. s2 = (x*y - x, y**2 - y)
  108. assert solve(s2) == [{y: 1}, {x: 0, y: 0}]
  109. gens = (x, y)
  110. for seq in (s1, s2):
  111. (f, g), opt = parallel_poly_from_expr(seq, *gens)
  112. raises(SolveFailed, lambda: solve_biquadratic(f, g, opt))
  113. seq = (x**2 + y**2 - 2, y**2 - 1)
  114. (f, g), opt = parallel_poly_from_expr(seq, *gens)
  115. assert solve_biquadratic(f, g, opt) == [
  116. (-1, -1), (-1, 1), (1, -1), (1, 1)]
  117. ans = [(0, -1), (0, 1)]
  118. seq = (x**2 + y**2 - 1, y**2 - 1)
  119. (f, g), opt = parallel_poly_from_expr(seq, *gens)
  120. assert solve_biquadratic(f, g, opt) == ans
  121. seq = (x**2 + y**2 - 1, x**2 - x + y**2 - 1)
  122. (f, g), opt = parallel_poly_from_expr(seq, *gens)
  123. assert solve_biquadratic(f, g, opt) == ans
  124. def test_solve_triangulated():
  125. f_1 = x**2 + y + z - 1
  126. f_2 = x + y**2 + z - 1
  127. f_3 = x + y + z**2 - 1
  128. a, b = sqrt(2) - 1, -sqrt(2) - 1
  129. assert solve_triangulated([f_1, f_2, f_3], x, y, z) == \
  130. [(0, 0, 1), (0, 1, 0), (1, 0, 0)]
  131. dom = QQ.algebraic_field(sqrt(2))
  132. assert solve_triangulated([f_1, f_2, f_3], x, y, z, domain=dom) == \
  133. [(0, 0, 1), (0, 1, 0), (1, 0, 0), (a, a, a), (b, b, b)]
  134. a, b = CRootOf(z**2 + 2*z - 1, 0), CRootOf(z**2 + 2*z - 1, 1)
  135. assert solve_triangulated([f_1, f_2, f_3], x, y, z, extension=True) == \
  136. [(0, 0, 1), (0, 1, 0), (1, 0, 0), (a, a, a), (b, b, b)]
  137. def test_solve_issue_3686():
  138. roots = solve_poly_system([((x - 5)**2/250000 + (y - Rational(5, 10))**2/250000) - 1, x], x, y)
  139. assert roots == [(0, S.Half - 15*sqrt(1111)), (0, S.Half + 15*sqrt(1111))]
  140. roots = solve_poly_system([((x - 5)**2/250000 + (y - 5.0/10)**2/250000) - 1, x], x, y)
  141. # TODO: does this really have to be so complicated?!
  142. assert len(roots) == 2
  143. assert roots[0][0] == 0
  144. assert roots[0][1].epsilon_eq(-499.474999374969, 1e12)
  145. assert roots[1][0] == 0
  146. assert roots[1][1].epsilon_eq(500.474999374969, 1e12)
  147. def test_factor_system():
  148. assert factor_system([x**2 + 2*x + 1]) == [[x + 1]]
  149. assert factor_system([x**2 + 2*x + 1, y**2 + 2*y + 1]) == [[x + 1, y + 1]]
  150. assert factor_system([x**2 + 1]) == [[x**2 + 1]]
  151. assert factor_system([]) == [[]]
  152. assert factor_system([x**2 + y**2 + 2*x*y, x**2 - 2], extension=sqrt(2)) == [
  153. [x + y, x + sqrt(2)],
  154. [x + y, x - sqrt(2)],
  155. ]
  156. assert factor_system([x**2 + 1, y**2 + 1], gaussian=True) == [
  157. [x + I, y + I],
  158. [x + I, y - I],
  159. [x - I, y + I],
  160. [x - I, y - I],
  161. ]
  162. assert factor_system([x**2 + 1, y**2 + 1], domain=QQ_I) == [
  163. [x + I, y + I],
  164. [x + I, y - I],
  165. [x - I, y + I],
  166. [x - I, y - I],
  167. ]
  168. assert factor_system([0]) == [[]]
  169. assert factor_system([1]) == []
  170. assert factor_system([0 , x]) == [[x]]
  171. assert factor_system([1, 0, x]) == []
  172. assert factor_system([x**4 - 1, y**6 - 1]) == [
  173. [x**2 + 1, y**2 + y + 1],
  174. [x**2 + 1, y**2 - y + 1],
  175. [x**2 + 1, y + 1],
  176. [x**2 + 1, y - 1],
  177. [x + 1, y**2 + y + 1],
  178. [x + 1, y**2 - y + 1],
  179. [x - 1, y**2 + y + 1],
  180. [x - 1, y**2 - y + 1],
  181. [x + 1, y + 1],
  182. [x + 1, y - 1],
  183. [x - 1, y + 1],
  184. [x - 1, y - 1],
  185. ]
  186. assert factor_system([(x - 1)*(y - 2), (y - 2)*(z - 3)]) == [
  187. [x - 1, z - 3],
  188. [y - 2]
  189. ]
  190. assert factor_system([sin(x)**2 + cos(x)**2 - 1, x]) == [
  191. [x, sin(x)**2 + cos(x)**2 - 1],
  192. ]
  193. assert factor_system([sin(x)**2 + cos(x)**2 - 1]) == [
  194. [sin(x)**2 + cos(x)**2 - 1]
  195. ]
  196. assert factor_system([sin(x)**2 + cos(x)**2]) == [
  197. [sin(x)**2 + cos(x)**2]
  198. ]
  199. assert factor_system([a*x, y, a]) == [[y, a]]
  200. assert factor_system([a*x, y, a], [x, y]) == []
  201. assert factor_system([a ** 2 * x, y], [x, y]) == [[x, y]]
  202. assert factor_system([a*x*(x - 1), b*y, c], [x, y]) == []
  203. assert factor_system([a*x*(x - 1), b*y, c], [x, y, c]) == [
  204. [x - 1, y, c],
  205. [x, y, c],
  206. ]
  207. assert factor_system([a*x*(x - 1), b*y, c]) == [
  208. [x - 1, y, c],
  209. [x, y, c],
  210. [x - 1, b, c],
  211. [x, b, c],
  212. [y, a, c],
  213. [a, b, c],
  214. ]
  215. assert factor_system([x**2 - 2], [y]) == []
  216. assert factor_system([x**2 - 2], [x]) == [[x**2 - 2]]
  217. assert factor_system([cos(x)**2 - sin(x)**2, cos(x)**2 + sin(x)**2 - 1]) == [
  218. [sin(x)**2 + cos(x)**2 - 1, sin(x) + cos(x)],
  219. [sin(x)**2 + cos(x)**2 - 1, -sin(x) + cos(x)],
  220. ]
  221. assert factor_system([(cos(x) + sin(x))**2 - 1, cos(x)**2 - sin(x)**2 - cos(2*x)]) == [
  222. [sin(x)**2 - cos(x)**2 + cos(2*x), sin(x) + cos(x) + 1],
  223. [sin(x)**2 - cos(x)**2 + cos(2*x), sin(x) + cos(x) - 1],
  224. ]
  225. assert factor_system([(cos(x) + sin(x))*exp(y) - 1, (cos(x) - sin(x))*exp(y) - 1]) == [
  226. [exp(y)*sin(x) + exp(y)*cos(x) - 1, -exp(y)*sin(x) + exp(y)*cos(x) - 1]
  227. ]
  228. def test_factor_system_poly():
  229. px = lambda e: Poly(e, x)
  230. pxab = lambda e: Poly(e, x, domain=ZZ[a, b])
  231. pxI = lambda e: Poly(e, x, domain=QQ_I)
  232. pxyz = lambda e: Poly(e, (x, y, z))
  233. assert factor_system_poly([px(x**2 - 1), px(x**2 - 4)]) == [
  234. [px(x + 2), px(x + 1)],
  235. [px(x + 2), px(x - 1)],
  236. [px(x + 1), px(x - 2)],
  237. [px(x - 1), px(x - 2)],
  238. ]
  239. assert factor_system_poly([px(x**2 - 1)]) == [[px(x + 1)], [px(x - 1)]]
  240. assert factor_system_poly([pxyz(x**2*y - y), pxyz(x**2*z - z)]) == [
  241. [pxyz(x + 1)],
  242. [pxyz(x - 1)],
  243. [pxyz(y), pxyz(z)],
  244. ]
  245. assert factor_system_poly([px(x**2*(x - 1)**2), px(x*(x - 1))]) == [
  246. [px(x)],
  247. [px(x - 1)],
  248. ]
  249. assert factor_system_poly([pxyz(x**2 + y*x), pxyz(x**2 + z*x)]) == [
  250. [pxyz(x + y), pxyz(x + z)],
  251. [pxyz(x)],
  252. ]
  253. assert factor_system_poly([pxab((a - 1)*(x - 2)), pxab((b - 3)*(x - 2))]) == [
  254. [pxab(x - 2)],
  255. [pxab(a - 1), pxab(b - 3)],
  256. ]
  257. assert factor_system_poly([pxI(x**2 + 1)]) == [[pxI(x + I)], [pxI(x - I)]]
  258. assert factor_system_poly([]) == [[]]
  259. assert factor_system_poly([px(1)]) == []
  260. assert factor_system_poly([px(0), px(x)]) == [[px(x)]]
  261. def test_factor_system_cond():
  262. assert factor_system_cond([x ** 2 - 1, x ** 2 - 4]) == [
  263. [x + 2, x + 1],
  264. [x + 2, x - 1],
  265. [x + 1, x - 2],
  266. [x - 1, x - 2],
  267. ]
  268. assert factor_system_cond([1]) == []
  269. assert factor_system_cond([0]) == [[]]
  270. assert factor_system_cond([1, x]) == []
  271. assert factor_system_cond([0, x]) == [[x]]
  272. assert factor_system_cond([]) == [[]]
  273. assert factor_system_cond([x**2 + y*x]) == [[x + y], [x]]
  274. assert factor_system_cond([(a - 1)*(x - 2), (b - 3)*(x - 2)], [x]) == [
  275. [x - 2],
  276. [a - 1, b - 3],
  277. ]
  278. assert factor_system_cond([a * (x - 1), b], [x]) == [[x - 1, b], [a, b]]
  279. assert factor_system_cond([a*x*(x-1), b*y, c], [x, y]) == [
  280. [x - 1, y, c],
  281. [x, y, c],
  282. [x - 1, b, c],
  283. [x, b, c],
  284. [y, a, c],
  285. [a, b, c],
  286. ]
  287. assert factor_system_cond([x*(x-1), y], [x, y]) == [[x - 1, y], [x, y]]
  288. assert factor_system_cond([a*x, y, a], [x, y]) == [[y, a]]
  289. assert factor_system_cond([a*x, b*x], [x, y]) == [[x], [a, b]]
  290. assert factor_system_cond([a*b*x, y], [x, y]) == [[x, y], [y, a*b]]
  291. assert factor_system_cond([a*b*x, y]) == [[x, y], [y, a], [y, b]]
  292. assert factor_system_cond([a**2*x, y], [x, y]) == [[x, y], [y, a]]
  293. def test_factor_system_bool():
  294. eqs = [a*(x - 1)*(y - 1), b*(x - 2)*(y - 1)*(y - 2)]
  295. assert factor_system_bool(eqs, [x, y]) == (
  296. Eq(y - 1, 0)
  297. | (Eq(a, 0) & Eq(b, 0))
  298. | (Eq(a, 0) & Eq(x - 2, 0))
  299. | (Eq(a, 0) & Eq(y - 2, 0))
  300. | (Eq(b, 0) & Eq(x - 1, 0))
  301. | (Eq(x - 2, 0) & Eq(x - 1, 0))
  302. | (Eq(x - 1, 0) & Eq(y - 2, 0))
  303. )
  304. assert factor_system_bool([x - 1], [x]) == Eq(x - 1, 0)
  305. assert factor_system_bool([(x - 1)*(x - 2)], [x]) == Eq(x - 2, 0) | Eq(x - 1, 0)
  306. assert factor_system_bool([], [x]) == True
  307. assert factor_system_bool([0], [x]) == True
  308. assert factor_system_bool([1], [x]) == False
  309. assert factor_system_bool([a], [x]) == Eq(a, 0)
  310. assert factor_system_bool([a * x, y, a], [x, y]) == Eq(a, 0) & Eq(y, 0)
  311. assert (factor_system_bool([a*x, b*y*x, a], [x, y]) == (
  312. Eq(a, 0) & Eq(b, 0))
  313. | (Eq(a, 0) & Eq(x, 0))
  314. | (Eq(a, 0) & Eq(y, 0)))
  315. assert (factor_system_bool([a*x, b*x], [x, y]) == Eq(x, 0) |
  316. (Eq(a, 0) & Eq(b, 0)))
  317. assert (factor_system_bool([a*b*x, y], [x, y]) == (
  318. Eq(x, 0) & Eq(y, 0)) |
  319. (Eq(y, 0) & Eq(a*b, 0)))
  320. assert (factor_system_bool([a**2*x, y], [x, y]) == (
  321. Eq(a, 0) & Eq(y, 0)) |
  322. (Eq(x, 0) & Eq(y, 0)))
  323. assert factor_system_bool([a*x*y, b*y*z], [x, y, z]) == (
  324. Eq(y, 0)
  325. | (Eq(a, 0) & Eq(b, 0))
  326. | (Eq(a, 0) & Eq(z, 0))
  327. | (Eq(b, 0) & Eq(x, 0))
  328. | (Eq(x, 0) & Eq(z, 0))
  329. )
  330. assert factor_system_bool([a*(x - 1), b], [x]) == (
  331. (Eq(a, 0) & Eq(b, 0))
  332. | (Eq(x - 1, 0) & Eq(b, 0))
  333. )
  334. def test_factor_sets():
  335. #
  336. from random import randint
  337. def generate_random_system(n_eqs=3, n_factors=2, max_val=10):
  338. return [
  339. [randint(0, max_val) for _ in range(randint(1, n_factors))]
  340. for _ in range(n_eqs)
  341. ]
  342. test_cases = [
  343. [[1, 2], [1, 3]],
  344. [[1, 2], [3, 4]],
  345. [[1], [1, 2], [2]],
  346. ]
  347. for case in test_cases:
  348. assert _factor_sets(case) == _factor_sets_slow(case)
  349. for _ in range(100):
  350. system = generate_random_system()
  351. assert _factor_sets(system) == _factor_sets_slow(system)