test_solvers.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615
  1. import pytest
  2. from sympy.core.function import expand_mul
  3. from sympy.core.numbers import (I, Rational)
  4. from sympy.core.singleton import S
  5. from sympy.core.symbol import (Symbol, symbols)
  6. from sympy.core.sympify import sympify
  7. from sympy.simplify.simplify import simplify
  8. from sympy.matrices.exceptions import (ShapeError, NonSquareMatrixError)
  9. from sympy.matrices import (
  10. ImmutableMatrix, Matrix, eye, ones, ImmutableDenseMatrix, dotprodsimp)
  11. from sympy.matrices.determinant import _det_laplace
  12. from sympy.testing.pytest import raises
  13. from sympy.matrices.exceptions import NonInvertibleMatrixError
  14. from sympy.polys.matrices.exceptions import DMShapeError
  15. from sympy.solvers.solveset import linsolve
  16. from sympy.abc import x, y
  def test_issue_17247_expression_blowup_29():
      """Solve a sparse complex 4x4 system with gauss_jordan_solve under dotprodsimp.

      The expected parameter matrix is empty (0 x 1): the solution has no
      free parameters.
      """
      coeffs = Matrix(S('''[
  [ -3/4, 45/32 - 37*I/16, 0, 0],
  [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128],
  [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0],
  [ 0, 0, 0, -177/128 - 1369*I/128]]'''))
      expected_sol = Matrix(S('''[
  [ -32549314808672/3306971225785 - 17397006745216*I/3306971225785],
  [ 67439348256/3306971225785 - 9167503335872*I/3306971225785],
  [-15091965363354518272/21217636514687010905 + 16890163109293858304*I/21217636514687010905],
  [ -11328/952745 + 87616*I/952745]]'''))
      expected_params = Matrix(0, 1, [])
      with dotprodsimp(True):
          sol, params = coeffs.gauss_jordan_solve(ones(4, 1))
          assert sol == expected_sol
          assert params == expected_params
  def test_issue_17247_expression_blowup_30():
      """Solve a sparse complex 4x4 system with cholesky_solve under dotprodsimp."""
      coeffs = Matrix(S('''[
  [ -3/4, 45/32 - 37*I/16, 0, 0],
  [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128],
  [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0],
  [ 0, 0, 0, -177/128 - 1369*I/128]]'''))
      expected = Matrix(S('''[
  [ -32549314808672/3306971225785 - 17397006745216*I/3306971225785],
  [ 67439348256/3306971225785 - 9167503335872*I/3306971225785],
  [-15091965363354518272/21217636514687010905 + 16890163109293858304*I/21217636514687010905],
  [ -11328/952745 + 87616*I/952745]]'''))
      with dotprodsimp(True):
          solution = coeffs.cholesky_solve(ones(4, 1))
          assert solution == expected
  41. # @XFAIL # This calculation hangs with dotprodsimp.
  42. # def test_issue_17247_expression_blowup_31():
  43. # M = Matrix([
  44. # [x + 1, 1 - x, 0, 0],
  45. # [1 - x, x + 1, 0, x + 1],
  46. # [ 0, 1 - x, x + 1, 0],
  47. # [ 0, 0, 0, x + 1]])
  48. # with dotprodsimp(True):
  49. # assert M.LDLsolve(ones(4, 1)) == Matrix([
  50. # [(x + 1)/(4*x)],
  51. # [(x - 1)/(4*x)],
  52. # [(x + 1)/(4*x)],
  53. # [ 1/(x + 1)]])
  def test_LUsolve_iszerofunc():
      """LUsolve must honour a caller-supplied ``iszerofunc``.

      Taken from https://github.com/sympy/sympy/issues/24679. ``M[0, 0]`` is
      mathematically zero but not syntactically zero; ``is_zero_func`` judges
      zero-ness only from a numeric sample (``e._random()``).
      """
      M = Matrix([[(x + 1)**2 - (x**2 + 2*x + 1), x], [x, 0]])
      b = Matrix([1, 1])

      def is_zero_func(e):
          # Report "zero" exactly when numeric sampling gives a falsy result.
          # A named function instead of an assigned lambda (PEP 8, E731).
          return not e._random()

      x_exp = Matrix([1/x, (1 - (-x**2 - 2*x + (x + 1)**2 - 1)/x)/x])
      assert (x_exp - M.LUsolve(b, iszerofunc=is_zero_func)) == Matrix([0, 0])
  def test_issue_17247_expression_blowup_32():
      """LUsolve under dotprodsimp on a 4x4 matrix whose entries are linear in x."""
      plus, minus = x + 1, 1 - x
      M = Matrix([
          [plus, minus, 0, 0],
          [minus, plus, 0, plus],
          [0, minus, plus, 0],
          [0, 0, 0, plus]])
      expected = Matrix([plus/(4*x), (x - 1)/(4*x), plus/(4*x), 1/plus])
      with dotprodsimp(True):
          assert M.LUsolve(ones(4, 1)) == expected
  def test_LUsolve():
      """LUsolve on square, tall, inconsistent, underdetermined and singular systems."""
      # Square systems: build b from a known solution and recover it.
      square_cases = [
          (Matrix([[2, 3, 5],
                   [3, 6, 2],
                   [8, 3, 6]]), Matrix(3, 1, [3, 7, 5])),
          (Matrix([[0, -1, 2],
                   [5, 10, 7],
                   [8, 3, 4]]), Matrix(3, 1, [-1, 2, 5])),
      ]
      for coeffs, expected in square_cases:
          assert coeffs.LUsolve(coeffs*expected) == expected

      # Tall 3x2 system (issue 14548): consistent vs. inconsistent right-hand side.
      tall = Matrix([[2, 1], [1, 0], [1, 0]])
      assert tall.LUsolve(Matrix([3, 1, 1])) == Matrix([1, 1])
      raises(ValueError, lambda: tall.LUsolve(Matrix([3, 1, 2])))  # inconsistent

      # Tall 6x3 system with an exact solution.
      tall = Matrix([[0, -1, 2],
                     [5, 10, 7],
                     [8, 3, 4],
                     [2, 3, 5],
                     [3, 6, 2],
                     [8, 3, 6]])
      expected = Matrix([2, 1, -4])
      assert tall.LUsolve(tall*expected) == expected

      # Underdetermined systems are not supported.
      wide = Matrix([[0, -1, 2], [5, 10, 7]])
      rhs = wide*Matrix([-1, 2, 0])
      raises(NotImplementedError, lambda: wide.LUsolve(rhs))

      # Row 3 is all zeros, so the matrix is singular.
      singular = Matrix(4, 4, lambda i, j: 1/(i+j+1) if i != 3 else 0)
      raises(NonInvertibleMatrixError,
             lambda: singular.LUsolve(Matrix.zeros(4, 1)))
  def test_LUsolve_noncommutative():
      """LUsolve with noncommutative symbols must keep the factor order intact."""
      a0, a1, a2, a3 = symbols("a:4", commutative=False)
      b0, b1 = symbols("b:2", commutative=False)
      A = Matrix([[a0, a1], [a2, a3]])
      product = A * A.LUsolve(Matrix([b0, b1]))
      assert product[0, 0].expand() == b0

      # SymPy's simplification of noncommutative expressions is very limited,
      # so the second entry is compared with its exact unsimplified form,
      # written in terms of a0**-1 and the Schur complement of a0.
      a0_inv = a0**(-1)
      schur = -a2*a0_inv*a1 + a3
      reduced_b1 = -a2*a0_inv*b0 + b1
      expected = (a2*a0_inv*(-a1*schur**(-1)*reduced_b1 + b0)
                  + a3*schur**(-1)*reduced_b1)
      assert product[1, 0] == expected
  def test_QRsolve():
      """QRsolve recovers known solutions for vector and matrix right-hand sides."""
      systems = [
          (Matrix([[2, 3, 5],
                   [3, 6, 2],
                   [8, 3, 6]]),
           [Matrix(3, 1, [3, 7, 5]), Matrix([[1, 2], [3, 4], [5, 6]])]),
          (Matrix([[0, -1, 2],
                   [5, 10, 7],
                   [8, 3, 4]]),
           [Matrix(3, 1, [-1, 2, 5]), Matrix([[7, 8], [9, 10], [11, 12]])]),
      ]
      for coeffs, solutions in systems:
          for expected in solutions:
              assert coeffs.QRsolve(coeffs*expected) == expected
  def test_errors():
      """LUsolve rejects a right-hand side with a different number of rows."""
      with raises(ShapeError):
          Matrix([1]).LUsolve(Matrix([[1, 2], [3, 4]]))
  def test_cholesky_solve():
      """cholesky_solve on real, complex and symbolic systems."""
      # Solutions compared directly.
      exact_cases = [
          (Matrix([[2, 3, 5], [3, 6, 2], [8, 3, 6]]), Matrix(3, 1, [3, 7, 5])),
          (Matrix([[0, -1, 2], [5, 10, 7], [8, 3, 4]]), Matrix(3, 1, [-1, 2, 5])),
          (Matrix(((1, 5), (5, 1))), Matrix((4, -3))),
      ]
      for coeffs, expected in exact_cases:
          assert coeffs.cholesky_solve(coeffs*expected) == expected

      # Complex systems: solutions compared after expand_mul.
      complex_cases = [
          (Matrix(((9, 3*I), (-3*I, 5))), Matrix((-2, 1))),
          (Matrix(((9*I, 3), (-3 + I, 5))), Matrix((2 + 3*I, -1))),
      ]
      for coeffs, expected in complex_cases:
          assert expand_mul(coeffs.cholesky_solve(coeffs*expected)) == expected

      # Symbolic symmetric 2x2 system: check A*sol reproduces b after simplify.
      a00, a01, a11, b0, b1 = symbols('a00, a01, a11, b0, b1')
      coeffs = Matrix(((a00, a01), (a01, a11)))
      rhs = Matrix((b0, b1))
      sol = coeffs.cholesky_solve(rhs)
      assert simplify(coeffs*sol) == rhs
  def test_LDLsolve():
      """LDLsolve on real, complex and non-square systems."""
      # Solutions compared directly.
      exact_cases = [
          (Matrix([[2, 3, 5], [3, 6, 2], [8, 3, 6]]), Matrix(3, 1, [3, 7, 5])),
          (Matrix([[0, -1, 2], [5, 10, 7], [8, 3, 4]]), Matrix(3, 1, [-1, 2, 5])),
      ]
      for coeffs, expected in exact_cases:
          assert coeffs.LDLsolve(coeffs*expected) == expected

      # Solutions compared after expand_mul.
      expanded_cases = [
          (Matrix(((9, 3*I), (-3*I, 5))), Matrix((-2, 1))),
          (Matrix(((9*I, 3), (-3 + I, 5))), Matrix((2 + 3*I, -1))),
          (Matrix(((9, 3), (3, 9))), Matrix((1, 1))),
      ]
      for coeffs, expected in expanded_cases:
          assert expand_mul(coeffs.LDLsolve(coeffs * expected)) == expected

      # A non-square (2x3) coefficient matrix is not supported.
      wide = Matrix([[-5, -3, -4], [-3, -7, 7]])
      rhs = wide * Matrix([[8], [7], [-2]])
      raises(NotImplementedError, lambda: wide.LDLsolve(rhs))
  def test_lower_triangular_solve():
      """lower_triangular_solve: argument validation and the identity case."""
      # Non-square (2x1) coefficient matrix.
      raises(NonSquareMatrixError,
             lambda: Matrix([1, 0]).lower_triangular_solve(Matrix([0, 1])))
      # Right-hand side with the wrong number of rows.
      raises(ShapeError,
             lambda: Matrix([[1, 0], [0, 1]]).lower_triangular_solve(Matrix([1])))
      # Coefficient matrix with a nonzero entry above the diagonal.
      raises(ValueError,
             lambda: Matrix([[2, 1], [1, 2]]).lower_triangular_solve(
                 Matrix([[1, 0], [0, 1]])))

      identity = Matrix([[1, 0], [0, 1]])
      for rhs in (Matrix([[x, y], [y, x]]), Matrix([[4, 8], [2, 9]])):
          assert identity.lower_triangular_solve(rhs) == rhs
  def test_upper_triangular_solve():
      """upper_triangular_solve: argument validation and the identity case."""
      # Non-square (2x1) coefficient matrix.
      raises(NonSquareMatrixError,
             lambda: Matrix([1, 0]).upper_triangular_solve(Matrix([0, 1])))
      # Right-hand side with the wrong number of rows.
      raises(ShapeError,
             lambda: Matrix([[1, 0], [0, 1]]).upper_triangular_solve(Matrix([1])))
      # Coefficient matrix with a nonzero entry below the diagonal; note this
      # is TypeError here, while lower_triangular_solve uses ValueError for
      # the analogous case.
      raises(TypeError,
             lambda: Matrix([[2, 1], [1, 2]]).upper_triangular_solve(
                 Matrix([[1, 0], [0, 1]])))

      identity = Matrix([[1, 0], [0, 1]])
      for rhs in (Matrix([[x, y], [y, x]]), Matrix([[2, 4], [3, 8]])):
          assert identity.upper_triangular_solve(rhs) == rhs
  def test_diagonal_solve():
      """diagonal_solve with 2*I halves the RHS; non-diagonal input raises TypeError."""
      raises(TypeError, lambda: Matrix([1, 1]).diagonal_solve(Matrix([1])))
      rhs = Matrix([[x, y], [y, x]])
      assert (2*eye(2)).diagonal_solve(rhs) == rhs/2
      lower = Matrix([[1, 0], [1, 2]])
      raises(TypeError, lambda: lower.diagonal_solve(rhs))
  def test_pinv_solve():
      """pinv_solve on determined, underdetermined and overdetermined systems."""
      # Fully determined system (unique result, identical to other solvers),
      # with a one-column and then a three-column right-hand side.
      A = Matrix([[1, 5], [7, 9]])
      determined = [
          (Matrix([12, 13]),
           Matrix([Rational(-43, 26), Rational(71, 26)])),
          (Matrix([[12, 13, 14], [15, 16, 17]]),
           Matrix([[-33, -37, -41], [69, 75, 81]]) / 26),
      ]
      for B, expected in determined:
          sol = A.pinv_solve(B)
          assert sol == A.cholesky_solve(B)
          assert sol == A.LDLsolve(B)
          assert sol == expected
          assert A * A.pinv() * B == B

      # Underdetermined system (infinitely many results).
      A = Matrix([[1, 0, 1], [0, 1, 1]])
      B = Matrix([5, 7])
      solution = A.pinv_solve(B)
      # Recover the dummy symbols used in the solution by name.
      w = {s.name: s for s in solution.atoms(Symbol)}
      assert solution == Matrix([[w['w0_0']/3 + w['w1_0']/3 - w['w2_0']/3 + 1],
                                 [w['w0_0']/3 + w['w1_0']/3 - w['w2_0']/3 + 3],
                                 [-w['w0_0']/3 - w['w1_0']/3 + w['w2_0']/3 + 4]])
      assert A * A.pinv() * B == B

      # Overdetermined system (least-squares result).
      A = Matrix([[1, 0], [0, 0], [0, 1]])
      B = Matrix([3, 2, 1])
      assert A.pinv_solve(B) == Matrix([3, 1])
      # The least-squares solution does not reproduce B exactly.
      assert A * A.pinv() * B != B
  def test_pinv_rank_deficient():
      """Pseudoinverse of rank-deficient matrices, and pinv_solve with them."""
      # Check the four Moore-Penrose conditions for every matrix and for each
      # pinv method. The assertion messages identify which method/matrix failed.
      As = [Matrix([[1, 1, 1], [2, 2, 2]]),
            Matrix([[1, 0], [0, 0]]),
            Matrix([[1, 2], [2, 4], [3, 6]])]
      for method in ("RD", "ED"):
          for A in As:
              A_pinv = A.pinv(method=method)
              AAp = A * A_pinv
              ApA = A_pinv * A
              # A A+ A == A and A+ A A+ == A+ (compared after simplify).
              assert simplify(AAp * A) == A, (method, A)
              assert simplify(ApA * A_pinv) == A_pinv, (method, A)
              # A A+ and A+ A are Hermitian.
              assert AAp.H == AAp, (method, A)
              assert ApA.H == ApA, (method, A)

      # Solving with a rank-deficient matrix: the solution is non-unique and
      # carries one free dummy symbol named w1_0. With B = [3, 0], A A+ B
      # reproduces B (exact solution); with B = [3, 1] it does not
      # (least-squares solution).
      A = Matrix([[1, 0], [0, 0]])
      for B, exact in ((Matrix([3, 0]), True), (Matrix([3, 1]), False)):
          solution = A.pinv_solve(B)
          w1 = solution.atoms(Symbol).pop()
          assert w1.name == 'w1_0'
          assert solution == Matrix([3, w1])
          if exact:
              assert A * A.pinv() * B == B
          else:
              assert A * A.pinv() * B != B
  def test_gauss_jordan_solve():
      """gauss_jordan_solve on square, tall and wide systems of every rank.

      Parametrized solutions contain symbols named ``tau0``, ``tau1``, ...;
      ``named_symbols`` looks them up by name so the expected solutions can
      be written out.
      """
      def named_symbols(expr):
          # Map every Symbol name occurring in ``expr`` to the Symbol itself.
          return {s.name: s for s in expr.atoms(Symbol)}

      # Square, full rank, unique solution
      A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
      sol, params = A.gauss_jordan_solve(Matrix([3, 6, 9]))
      assert sol == Matrix([[-1], [2], [0]])
      assert params == Matrix(0, 1, [])

      # Square, full rank, unique solution, B has more columns than rows
      B = Matrix([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
      sol, params = eye(3).gauss_jordan_solve(B)
      assert sol == B
      assert params == Matrix(0, 4, [])

      # Square, reduced rank, parametrized solution
      A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
      sol, params, freevar = A.gauss_jordan_solve(Matrix([3, 6, 9]), freevar=True)
      w = named_symbols(sol)
      assert sol == Matrix([[w['tau0'] - 1], [-2*w['tau0'] + 2], [w['tau0']]])
      assert params == Matrix([[w['tau0']]])
      assert freevar == [2]

      # Square, reduced rank, parametrized solution, B has two columns
      B = Matrix([[3, 4], [6, 8], [9, 12]])
      sol, params, freevar = A.gauss_jordan_solve(B, freevar=True)
      w = named_symbols(sol)
      assert sol == Matrix([[w['tau0'] - 1, w['tau1'] - Rational(4, 3)],
                            [-2*w['tau0'] + 2, -2*w['tau1'] + Rational(8, 3)],
                            [w['tau0'], w['tau1']]])
      assert params == Matrix([[w['tau0'], w['tau1']]])
      assert freevar == [2]

      # Square, reduced rank, parametrized solution
      rank_one = Matrix([[1, 2, 3], [2, 4, 6], [3, 6, 9]])
      sol, params = rank_one.gauss_jordan_solve(Matrix([0, 0, 0]))
      w = named_symbols(sol)
      assert sol == Matrix([[-2*w['tau0'] - 3*w['tau1']],
                            [w['tau0']], [w['tau1']]])
      assert params == Matrix([[w['tau0']], [w['tau1']]])

      # Square, reduced rank (zero matrix), parametrized solution
      sol, params = Matrix.zeros(3, 3).gauss_jordan_solve(Matrix([0, 0, 0]))
      w = named_symbols(sol)
      assert sol == Matrix([[w['tau0']], [w['tau1']], [w['tau2']]])
      assert params == Matrix([[w['tau0']], [w['tau1']], [w['tau2']]])

      # Square, reduced rank, no solution
      raises(ValueError, lambda: rank_one.gauss_jordan_solve(Matrix([0, 0, 1])))

      # Rectangular, tall, full rank, unique solution
      tall = Matrix([[1, 5, 3], [2, 1, 6], [1, 7, 9], [1, 4, 3]])
      sol, params = tall.gauss_jordan_solve(Matrix([0, 0, 1, 0]))
      assert sol == Matrix([[Rational(-1, 2)], [0], [Rational(1, 6)]])
      assert params == Matrix(0, 1, [])

      # Rectangular, tall, full rank, unique solution, B has less columns than rows
      sol, params = tall.gauss_jordan_solve(
          Matrix([[0, 0], [0, 0], [1, 2], [0, 0]]))
      assert sol == Matrix([[Rational(-1, 2), Rational(-2, 2)],
                            [0, 0],
                            [Rational(1, 6), Rational(2, 6)]])
      assert params == Matrix(0, 2, [])

      # Rectangular, tall, full rank, no solution: a single column, then two
      # columns where only the 2nd, and then only the 1st, is inconsistent.
      for rhs in (Matrix([0, 0, 0, 1]),
                  Matrix([[0, 0], [0, 0], [1, 0], [0, 1]]),
                  Matrix([[0, 0], [0, 0], [0, 1], [1, 0]])):
          with raises(ValueError):
              tall.gauss_jordan_solve(rhs)

      # Rectangular, tall, reduced rank, parametrized solution
      tall_deficient = Matrix([[1, 5, 3], [2, 10, 6], [3, 15, 9], [1, 4, 3]])
      sol, params = tall_deficient.gauss_jordan_solve(Matrix([0, 0, 0, 1]))
      w = named_symbols(sol)
      assert sol == Matrix([[-3*w['tau0'] + 5], [-1], [w['tau0']]])
      assert params == Matrix([[w['tau0']]])

      # Rectangular, tall, reduced rank, no solution
      raises(ValueError,
             lambda: tall_deficient.gauss_jordan_solve(Matrix([0, 0, 1, 1])))

      # Rectangular, wide, full rank, parametrized solution
      wide = Matrix([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 1, 12]])
      sol, params = wide.gauss_jordan_solve(Matrix([1, 1, 1]))
      w = named_symbols(sol)
      assert sol == Matrix([[2*w['tau0'] - 1], [-3*w['tau0'] + 1], [0],
                            [w['tau0']]])
      assert params == Matrix([[w['tau0']]])

      # Rectangular, wide, reduced rank, parametrized solution
      wide_deficient = Matrix([[1, 2, 3, 4], [5, 6, 7, 8], [2, 4, 6, 8]])
      sol, params = wide_deficient.gauss_jordan_solve(Matrix([0, 1, 0]))
      w = named_symbols(sol)
      assert sol == Matrix([[w['tau0'] + 2*w['tau1'] + S.Half],
                            [-2*w['tau0'] - 3*w['tau1'] - Rational(1, 4)],
                            [w['tau0']], [w['tau1']]])
      assert params == Matrix([[w['tau0']], [w['tau1']]])

      # watch out for clashing symbols: a user symbol named "tau1" in the
      # system forces the generated parameters to use different names
      x0, x1, x2, _x0 = symbols('_tau0 _tau1 _tau2 tau1')
      M = Matrix([[0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, _x0]])
      sol, params = M[:, :-1].gauss_jordan_solve(M[:, -1:])
      assert params == Matrix(3, 1, [x0, x1, x2])
      assert sol == Matrix(5, 1, [x0, 0, x1, _x0, x2])

      # Rectangular, wide, reduced rank, no solution
      raises(ValueError,
             lambda: wide_deficient.gauss_jordan_solve(Matrix([1, 1, 1])))

      # Immutable input gives immutable outputs
      A = ImmutableMatrix([[1, 0], [0, 1]])
      B = ImmutableMatrix([1, 2])
      sol, params = A.gauss_jordan_solve(B)
      assert sol == ImmutableMatrix([1, 2])
      assert params == ImmutableMatrix(0, 1, [])
      assert sol.__class__ == ImmutableDenseMatrix
      assert params.__class__ == ImmutableDenseMatrix

      # Placement of free variables
      sol, params = Matrix([[1, 0, 0, 0], [0, 0, 0, 1]]).gauss_jordan_solve(
          Matrix([1, 1]))
      w = named_symbols(sol)
      assert sol == Matrix([[1], [w['tau0']], [w['tau1']], [1]])
      assert params == Matrix([[w['tau0']], [w['tau1']]])
  def test_linsolve_underdetermined_AND_gauss_jordan_solve():
      """Free-variable placement in gauss_jordan_solve and linsolve.

      See https://github.com/sympy/sympy/issues/19815 and
      https://github.com/sympy/sympy/issues/20046.
      """
      from sympy.abc import j, f

      # Placement of free variables as per issue 19815
      A = Matrix([[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                  [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                  [0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
                  [0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0],
                  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
                  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]])
      B = Matrix([1, 2, 1, 1, 1, 1, 1, 2])
      sol, params = A.gauss_jordan_solve(B)
      # Recover the generated parameters tau0..tau5 by name.
      by_name = {s.name: s for s in sol.atoms(Symbol)}
      t0, t1, t2, t3, t4, t5 = (by_name[f'tau{k}'] for k in range(6))
      assert params == Matrix([t0, t1, t2, t3, t4, t5])
      assert sol == Matrix([1 - t2,
                            t2,
                            1 - t0 + t1,
                            t0,
                            t3 + t4,
                            -t3 - t4 - t1,
                            1 - t2,
                            t1,
                            t2,
                            t3,
                            t4,
                            1 - t5,
                            t5,
                            1])

      # https://github.com/sympy/sympy/issues/20046
      augmented = Matrix([
          [1, 1, 1, 1, 1, 1, 1, 1, 1],
          [0, -1, 0, -1, 0, -1, 0, -1, -j],
          [0, 0, 0, 0, 1, 1, 1, 1, f]
      ])
      sol_1 = Matrix(next(iter(linsolve(augmented))))
      tau0, tau1, tau2, tau3, tau4 = symbols('tau:5')
      assert sol_1 == Matrix([[-f - j - tau0 + tau2 + tau4 + 1],
                              [j - tau1 - tau2 - tau4],
                              [tau0],
                              [tau1],
                              [f - tau2 - tau3 - tau4],
                              [tau2],
                              [tau3],
                              [tau4]])
      # https://github.com/sympy/sympy/issues/19815
      # The parametrized solution must satisfy the system identically.
      residual = augmented[:, :-1] * sol_1 - augmented[:, -1]
      assert residual == Matrix([[0], [0], [0]])
  @pytest.mark.parametrize("det_method", ["bird", "laplace"])
  @pytest.mark.parametrize("M, rhs", [
      (Matrix([[2, 3, 5], [3, 6, 2], [8, 3, 6]]), Matrix(3, 1, [3, 7, 5])),
      (Matrix([[2, 3, 5], [3, 6, 2], [8, 3, 6]]),
       Matrix([[1, 2], [3, 4], [5, 6]])),
      (Matrix(2, 2, symbols("a:4")), Matrix(2, 1, symbols("b:2"))),
  ])
  def test_cramer_solve(det_method, M, rhs):
      """cramer_solve agrees with LUsolve (after simplify) for each det method."""
      by_cramer = M.cramer_solve(rhs, det_method=det_method)
      by_lu = M.LUsolve(rhs)
      assert simplify(by_cramer - by_lu) == Matrix.zeros(M.rows, rhs.cols)
  @pytest.mark.parametrize("det_method, error", [
      ("bird", DMShapeError), (_det_laplace, NonSquareMatrixError)])
  def test_cramer_solve_errors(det_method, error):
      """cramer_solve rejects a non-square (2x3) coefficient matrix."""
      non_square = Matrix([[0, -1, 2], [5, 10, 7]])
      rhs = Matrix([-2, 15])
      with raises(error):
          non_square.cramer_solve(rhs, det_method=det_method)
  def test_solve():
      """solve raises ValueError for a singular system, consistent or not."""
      singular = Matrix([[1, 2], [2, 4]])
      for rhs in (Matrix([[3], [4]]),    # no solution
                  Matrix([[4], [8]])):   # infinitely many solutions
          with raises(ValueError):
              singular.solve(rhs)