solvers.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942
  1. from sympy.core.function import expand_mul
  2. from sympy.core.symbol import Dummy, uniquely_named_symbol, symbols
  3. from sympy.utilities.iterables import numbered_symbols
  4. from .exceptions import ShapeError, NonSquareMatrixError, NonInvertibleMatrixError
  5. from .eigen import _fuzzy_positive_definite
  6. from .utilities import _get_intermediate_simp, _iszero
  7. def _diagonal_solve(M, rhs):
  8. """Solves ``Ax = B`` efficiently, where A is a diagonal Matrix,
  9. with non-zero diagonal entries.
  10. Examples
  11. ========
  12. >>> from sympy import Matrix, eye
  13. >>> A = eye(2)*2
  14. >>> B = Matrix([[1, 2], [3, 4]])
  15. >>> A.diagonal_solve(B) == B/2
  16. True
  17. See Also
  18. ========
  19. sympy.matrices.dense.DenseMatrix.lower_triangular_solve
  20. sympy.matrices.dense.DenseMatrix.upper_triangular_solve
  21. gauss_jordan_solve
  22. cholesky_solve
  23. LDLsolve
  24. LUsolve
  25. QRsolve
  26. pinv_solve
  27. cramer_solve
  28. """
  29. if not M.is_diagonal():
  30. raise TypeError("Matrix should be diagonal")
  31. if rhs.rows != M.rows:
  32. raise TypeError("Size mismatch")
  33. return M._new(
  34. rhs.rows, rhs.cols, lambda i, j: rhs[i, j] / M[i, i])
  35. def _lower_triangular_solve(M, rhs):
  36. """Solves ``Ax = B``, where A is a lower triangular matrix.
  37. See Also
  38. ========
  39. upper_triangular_solve
  40. gauss_jordan_solve
  41. cholesky_solve
  42. diagonal_solve
  43. LDLsolve
  44. LUsolve
  45. QRsolve
  46. pinv_solve
  47. cramer_solve
  48. """
  49. from .dense import MutableDenseMatrix
  50. if not M.is_square:
  51. raise NonSquareMatrixError("Matrix must be square.")
  52. if rhs.rows != M.rows:
  53. raise ShapeError("Matrices size mismatch.")
  54. if not M.is_lower:
  55. raise ValueError("Matrix must be lower triangular.")
  56. dps = _get_intermediate_simp()
  57. X = MutableDenseMatrix.zeros(M.rows, rhs.cols)
  58. for j in range(rhs.cols):
  59. for i in range(M.rows):
  60. if M[i, i] == 0:
  61. raise TypeError("Matrix must be non-singular.")
  62. X[i, j] = dps((rhs[i, j] - sum(M[i, k]*X[k, j]
  63. for k in range(i))) / M[i, i])
  64. return M._new(X)
  65. def _lower_triangular_solve_sparse(M, rhs):
  66. """Solves ``Ax = B``, where A is a lower triangular matrix.
  67. See Also
  68. ========
  69. upper_triangular_solve
  70. gauss_jordan_solve
  71. cholesky_solve
  72. diagonal_solve
  73. LDLsolve
  74. LUsolve
  75. QRsolve
  76. pinv_solve
  77. cramer_solve
  78. """
  79. if not M.is_square:
  80. raise NonSquareMatrixError("Matrix must be square.")
  81. if rhs.rows != M.rows:
  82. raise ShapeError("Matrices size mismatch.")
  83. if not M.is_lower:
  84. raise ValueError("Matrix must be lower triangular.")
  85. dps = _get_intermediate_simp()
  86. rows = [[] for i in range(M.rows)]
  87. for i, j, v in M.row_list():
  88. if i > j:
  89. rows[i].append((j, v))
  90. X = rhs.as_mutable()
  91. for j in range(rhs.cols):
  92. for i in range(rhs.rows):
  93. for u, v in rows[i]:
  94. X[i, j] -= v*X[u, j]
  95. X[i, j] = dps(X[i, j] / M[i, i])
  96. return M._new(X)
  97. def _upper_triangular_solve(M, rhs):
  98. """Solves ``Ax = B``, where A is an upper triangular matrix.
  99. See Also
  100. ========
  101. lower_triangular_solve
  102. gauss_jordan_solve
  103. cholesky_solve
  104. diagonal_solve
  105. LDLsolve
  106. LUsolve
  107. QRsolve
  108. pinv_solve
  109. cramer_solve
  110. """
  111. from .dense import MutableDenseMatrix
  112. if not M.is_square:
  113. raise NonSquareMatrixError("Matrix must be square.")
  114. if rhs.rows != M.rows:
  115. raise ShapeError("Matrix size mismatch.")
  116. if not M.is_upper:
  117. raise TypeError("Matrix is not upper triangular.")
  118. dps = _get_intermediate_simp()
  119. X = MutableDenseMatrix.zeros(M.rows, rhs.cols)
  120. for j in range(rhs.cols):
  121. for i in reversed(range(M.rows)):
  122. if M[i, i] == 0:
  123. raise ValueError("Matrix must be non-singular.")
  124. X[i, j] = dps((rhs[i, j] - sum(M[i, k]*X[k, j]
  125. for k in range(i + 1, M.rows))) / M[i, i])
  126. return M._new(X)
  127. def _upper_triangular_solve_sparse(M, rhs):
  128. """Solves ``Ax = B``, where A is an upper triangular matrix.
  129. See Also
  130. ========
  131. lower_triangular_solve
  132. gauss_jordan_solve
  133. cholesky_solve
  134. diagonal_solve
  135. LDLsolve
  136. LUsolve
  137. QRsolve
  138. pinv_solve
  139. cramer_solve
  140. """
  141. if not M.is_square:
  142. raise NonSquareMatrixError("Matrix must be square.")
  143. if rhs.rows != M.rows:
  144. raise ShapeError("Matrix size mismatch.")
  145. if not M.is_upper:
  146. raise TypeError("Matrix is not upper triangular.")
  147. dps = _get_intermediate_simp()
  148. rows = [[] for i in range(M.rows)]
  149. for i, j, v in M.row_list():
  150. if i < j:
  151. rows[i].append((j, v))
  152. X = rhs.as_mutable()
  153. for j in range(rhs.cols):
  154. for i in reversed(range(rhs.rows)):
  155. for u, v in reversed(rows[i]):
  156. X[i, j] -= v*X[u, j]
  157. X[i, j] = dps(X[i, j] / M[i, i])
  158. return M._new(X)
  159. def _cholesky_solve(M, rhs):
  160. """Solves ``Ax = B`` using Cholesky decomposition,
  161. for a general square non-singular matrix.
  162. For a non-square matrix with rows > cols,
  163. the least squares solution is returned.
  164. See Also
  165. ========
  166. sympy.matrices.dense.DenseMatrix.lower_triangular_solve
  167. sympy.matrices.dense.DenseMatrix.upper_triangular_solve
  168. gauss_jordan_solve
  169. diagonal_solve
  170. LDLsolve
  171. LUsolve
  172. QRsolve
  173. pinv_solve
  174. cramer_solve
  175. """
  176. if M.rows < M.cols:
  177. raise NotImplementedError(
  178. 'Under-determined System. Try M.gauss_jordan_solve(rhs)')
  179. hermitian = True
  180. reform = False
  181. if M.is_symmetric():
  182. hermitian = False
  183. elif not M.is_hermitian:
  184. reform = True
  185. if reform or _fuzzy_positive_definite(M) is False:
  186. H = M.H
  187. M = H.multiply(M)
  188. rhs = H.multiply(rhs)
  189. hermitian = not M.is_symmetric()
  190. L = M.cholesky(hermitian=hermitian)
  191. Y = L.lower_triangular_solve(rhs)
  192. if hermitian:
  193. return (L.H).upper_triangular_solve(Y)
  194. else:
  195. return (L.T).upper_triangular_solve(Y)
  196. def _LDLsolve(M, rhs):
  197. """Solves ``Ax = B`` using LDL decomposition,
  198. for a general square and non-singular matrix.
  199. For a non-square matrix with rows > cols,
  200. the least squares solution is returned.
  201. Examples
  202. ========
  203. >>> from sympy import Matrix, eye
  204. >>> A = eye(2)*2
  205. >>> B = Matrix([[1, 2], [3, 4]])
  206. >>> A.LDLsolve(B) == B/2
  207. True
  208. See Also
  209. ========
  210. sympy.matrices.dense.DenseMatrix.LDLdecomposition
  211. sympy.matrices.dense.DenseMatrix.lower_triangular_solve
  212. sympy.matrices.dense.DenseMatrix.upper_triangular_solve
  213. gauss_jordan_solve
  214. cholesky_solve
  215. diagonal_solve
  216. LUsolve
  217. QRsolve
  218. pinv_solve
  219. cramer_solve
  220. """
  221. if M.rows < M.cols:
  222. raise NotImplementedError(
  223. 'Under-determined System. Try M.gauss_jordan_solve(rhs)')
  224. hermitian = True
  225. reform = False
  226. if M.is_symmetric():
  227. hermitian = False
  228. elif not M.is_hermitian:
  229. reform = True
  230. if reform or _fuzzy_positive_definite(M) is False:
  231. H = M.H
  232. M = H.multiply(M)
  233. rhs = H.multiply(rhs)
  234. hermitian = not M.is_symmetric()
  235. L, D = M.LDLdecomposition(hermitian=hermitian)
  236. Y = L.lower_triangular_solve(rhs)
  237. Z = D.diagonal_solve(Y)
  238. if hermitian:
  239. return (L.H).upper_triangular_solve(Z)
  240. else:
  241. return (L.T).upper_triangular_solve(Z)
  242. def _LUsolve(M, rhs, iszerofunc=_iszero):
  243. """Solve the linear system ``Ax = rhs`` for ``x`` where ``A = M``.
  244. This is for symbolic matrices, for real or complex ones use
  245. mpmath.lu_solve or mpmath.qr_solve.
  246. See Also
  247. ========
  248. sympy.matrices.dense.DenseMatrix.lower_triangular_solve
  249. sympy.matrices.dense.DenseMatrix.upper_triangular_solve
  250. gauss_jordan_solve
  251. cholesky_solve
  252. diagonal_solve
  253. LDLsolve
  254. QRsolve
  255. pinv_solve
  256. LUdecomposition
  257. cramer_solve
  258. """
  259. if rhs.rows != M.rows:
  260. raise ShapeError(
  261. "``M`` and ``rhs`` must have the same number of rows.")
  262. m = M.rows
  263. n = M.cols
  264. if m < n:
  265. raise NotImplementedError("Underdetermined systems not supported.")
  266. try:
  267. A, perm = M.LUdecomposition_Simple(
  268. iszerofunc=iszerofunc, rankcheck=True)
  269. except ValueError:
  270. raise NonInvertibleMatrixError("Matrix det == 0; not invertible.")
  271. dps = _get_intermediate_simp()
  272. b = rhs.permute_rows(perm).as_mutable()
  273. # forward substitution, all diag entries are scaled to 1
  274. for i in range(m):
  275. for j in range(min(i, n)):
  276. scale = A[i, j]
  277. b.zip_row_op(i, j, lambda x, y: dps(x - scale * y))
  278. # consistency check for overdetermined systems
  279. if m > n:
  280. for i in range(n, m):
  281. for j in range(b.cols):
  282. if not iszerofunc(b[i, j]):
  283. raise ValueError("The system is inconsistent.")
  284. b = b[0:n, :] # truncate zero rows if consistent
  285. # backward substitution
  286. for i in range(n - 1, -1, -1):
  287. for j in range(i + 1, n):
  288. scale = A[i, j]
  289. b.zip_row_op(i, j, lambda x, y: dps(x - scale * y))
  290. scale = A[i, i]
  291. b.row_op(i, lambda x, _: dps(scale**-1 * x))
  292. return rhs.__class__(b)
  293. def _QRsolve(M, b):
  294. """Solve the linear system ``Ax = b``.
  295. ``M`` is the matrix ``A``, the method argument is the vector
  296. ``b``. The method returns the solution vector ``x``. If ``b`` is a
  297. matrix, the system is solved for each column of ``b`` and the
  298. return value is a matrix of the same shape as ``b``.
  299. This method is slower (approximately by a factor of 2) but
  300. more stable for floating-point arithmetic than the LUsolve method.
  301. However, LUsolve usually uses an exact arithmetic, so you do not need
  302. to use QRsolve.
  303. This is mainly for educational purposes and symbolic matrices, for real
  304. (or complex) matrices use mpmath.qr_solve.
  305. See Also
  306. ========
  307. sympy.matrices.dense.DenseMatrix.lower_triangular_solve
  308. sympy.matrices.dense.DenseMatrix.upper_triangular_solve
  309. gauss_jordan_solve
  310. cholesky_solve
  311. diagonal_solve
  312. LDLsolve
  313. LUsolve
  314. pinv_solve
  315. QRdecomposition
  316. cramer_solve
  317. """
  318. dps = _get_intermediate_simp(expand_mul, expand_mul)
  319. Q, R = M.QRdecomposition()
  320. y = Q.T * b
  321. # back substitution to solve R*x = y:
  322. # We build up the result "backwards" in the vector 'x' and reverse it
  323. # only in the end.
  324. x = []
  325. n = R.rows
  326. for j in range(n - 1, -1, -1):
  327. tmp = y[j, :]
  328. for k in range(j + 1, n):
  329. tmp -= R[j, k] * x[n - 1 - k]
  330. tmp = dps(tmp)
  331. x.append(tmp / R[j, j])
  332. return M.vstack(*x[::-1])
  333. def _gauss_jordan_solve(M, B, freevar=False):
  334. """
  335. Solves ``Ax = B`` using Gauss Jordan elimination.
  336. There may be zero, one, or infinite solutions. If one solution
  337. exists, it will be returned. If infinite solutions exist, it will
  338. be returned parametrically. If no solutions exist, It will throw
  339. ValueError.
  340. Parameters
  341. ==========
  342. B : Matrix
  343. The right hand side of the equation to be solved for. Must have
  344. the same number of rows as matrix A.
  345. freevar : boolean, optional
  346. Flag, when set to `True` will return the indices of the free
  347. variables in the solutions (column Matrix), for a system that is
  348. undetermined (e.g. A has more columns than rows), for which
  349. infinite solutions are possible, in terms of arbitrary
  350. values of free variables. Default `False`.
  351. Returns
  352. =======
  353. x : Matrix
  354. The matrix that will satisfy ``Ax = B``. Will have as many rows as
  355. matrix A has columns, and as many columns as matrix B.
  356. params : Matrix
  357. If the system is underdetermined (e.g. A has more columns than
  358. rows), infinite solutions are possible, in terms of arbitrary
  359. parameters. These arbitrary parameters are returned as params
  360. Matrix.
  361. free_var_index : List, optional
  362. If the system is underdetermined (e.g. A has more columns than
  363. rows), infinite solutions are possible, in terms of arbitrary
  364. values of free variables. Then the indices of the free variables
  365. in the solutions (column Matrix) are returned by free_var_index,
  366. if the flag `freevar` is set to `True`.
  367. Examples
  368. ========
  369. >>> from sympy import Matrix
  370. >>> A = Matrix([[1, 2, 1, 1], [1, 2, 2, -1], [2, 4, 0, 6]])
  371. >>> B = Matrix([7, 12, 4])
  372. >>> sol, params = A.gauss_jordan_solve(B)
  373. >>> sol
  374. Matrix([
  375. [-2*tau0 - 3*tau1 + 2],
  376. [ tau0],
  377. [ 2*tau1 + 5],
  378. [ tau1]])
  379. >>> params
  380. Matrix([
  381. [tau0],
  382. [tau1]])
  383. >>> taus_zeroes = { tau:0 for tau in params }
  384. >>> sol_unique = sol.xreplace(taus_zeroes)
  385. >>> sol_unique
  386. Matrix([
  387. [2],
  388. [0],
  389. [5],
  390. [0]])
  391. >>> A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
  392. >>> B = Matrix([3, 6, 9])
  393. >>> sol, params = A.gauss_jordan_solve(B)
  394. >>> sol
  395. Matrix([
  396. [-1],
  397. [ 2],
  398. [ 0]])
  399. >>> params
  400. Matrix(0, 1, [])
  401. >>> A = Matrix([[2, -7], [-1, 4]])
  402. >>> B = Matrix([[-21, 3], [12, -2]])
  403. >>> sol, params = A.gauss_jordan_solve(B)
  404. >>> sol
  405. Matrix([
  406. [0, -2],
  407. [3, -1]])
  408. >>> params
  409. Matrix(0, 2, [])
  410. >>> from sympy import Matrix
  411. >>> A = Matrix([[1, 2, 1, 1], [1, 2, 2, -1], [2, 4, 0, 6]])
  412. >>> B = Matrix([7, 12, 4])
  413. >>> sol, params, freevars = A.gauss_jordan_solve(B, freevar=True)
  414. >>> sol
  415. Matrix([
  416. [-2*tau0 - 3*tau1 + 2],
  417. [ tau0],
  418. [ 2*tau1 + 5],
  419. [ tau1]])
  420. >>> params
  421. Matrix([
  422. [tau0],
  423. [tau1]])
  424. >>> freevars
  425. [1, 3]
  426. See Also
  427. ========
  428. sympy.matrices.dense.DenseMatrix.lower_triangular_solve
  429. sympy.matrices.dense.DenseMatrix.upper_triangular_solve
  430. cholesky_solve
  431. diagonal_solve
  432. LDLsolve
  433. LUsolve
  434. QRsolve
  435. pinv
  436. References
  437. ==========
  438. .. [1] https://en.wikipedia.org/wiki/Gaussian_elimination
  439. """
  440. from sympy.matrices import Matrix, zeros
  441. cls = M.__class__
  442. aug = M.hstack(M.copy(), B.copy())
  443. B_cols = B.cols
  444. row, col = aug[:, :-B_cols].shape
  445. # solve by reduced row echelon form
  446. A, pivots = aug.rref(simplify=True)
  447. A, v = A[:, :-B_cols], A[:, -B_cols:]
  448. pivots = list(filter(lambda p: p < col, pivots))
  449. rank = len(pivots)
  450. # Get index of free symbols (free parameters)
  451. # non-pivots columns are free variables
  452. free_var_index = [c for c in range(A.cols) if c not in pivots]
  453. # Bring to block form
  454. permutation = Matrix(pivots + free_var_index).T
  455. # check for existence of solutions
  456. # rank of aug Matrix should be equal to rank of coefficient matrix
  457. if not v[rank:, :].is_zero_matrix:
  458. raise ValueError("Linear system has no solution")
  459. # Free parameters
  460. # what are current unnumbered free symbol names?
  461. name = uniquely_named_symbol('tau', [aug],
  462. compare=lambda i: str(i).rstrip('1234567890'),
  463. modify=lambda s: '_' + s).name
  464. gen = numbered_symbols(name)
  465. tau = Matrix([next(gen) for k in range((col - rank)*B_cols)]).reshape(
  466. col - rank, B_cols)
  467. # Full parametric solution
  468. V = A[:rank, free_var_index]
  469. vt = v[:rank, :]
  470. free_sol = tau.vstack(vt - V * tau, tau)
  471. # Undo permutation
  472. sol = zeros(col, B_cols)
  473. for k in range(col):
  474. sol[permutation[k], :] = free_sol[k,:]
  475. sol, tau = cls(sol), cls(tau)
  476. if freevar:
  477. return sol, tau, free_var_index
  478. else:
  479. return sol, tau
  480. def _pinv_solve(M, B, arbitrary_matrix=None):
  481. """Solve ``Ax = B`` using the Moore-Penrose pseudoinverse.
  482. There may be zero, one, or infinite solutions. If one solution
  483. exists, it will be returned. If infinite solutions exist, one will
  484. be returned based on the value of arbitrary_matrix. If no solutions
  485. exist, the least-squares solution is returned.
  486. Parameters
  487. ==========
  488. B : Matrix
  489. The right hand side of the equation to be solved for. Must have
  490. the same number of rows as matrix A.
  491. arbitrary_matrix : Matrix
  492. If the system is underdetermined (e.g. A has more columns than
  493. rows), infinite solutions are possible, in terms of an arbitrary
  494. matrix. This parameter may be set to a specific matrix to use
  495. for that purpose; if so, it must be the same shape as x, with as
  496. many rows as matrix A has columns, and as many columns as matrix
  497. B. If left as None, an appropriate matrix containing dummy
  498. symbols in the form of ``wn_m`` will be used, with n and m being
  499. row and column position of each symbol.
  500. Returns
  501. =======
  502. x : Matrix
  503. The matrix that will satisfy ``Ax = B``. Will have as many rows as
  504. matrix A has columns, and as many columns as matrix B.
  505. Examples
  506. ========
  507. >>> from sympy import Matrix
  508. >>> A = Matrix([[1, 2, 3], [4, 5, 6]])
  509. >>> B = Matrix([7, 8])
  510. >>> A.pinv_solve(B)
  511. Matrix([
  512. [ _w0_0/6 - _w1_0/3 + _w2_0/6 - 55/18],
  513. [-_w0_0/3 + 2*_w1_0/3 - _w2_0/3 + 1/9],
  514. [ _w0_0/6 - _w1_0/3 + _w2_0/6 + 59/18]])
  515. >>> A.pinv_solve(B, arbitrary_matrix=Matrix([0, 0, 0]))
  516. Matrix([
  517. [-55/18],
  518. [ 1/9],
  519. [ 59/18]])
  520. See Also
  521. ========
  522. sympy.matrices.dense.DenseMatrix.lower_triangular_solve
  523. sympy.matrices.dense.DenseMatrix.upper_triangular_solve
  524. gauss_jordan_solve
  525. cholesky_solve
  526. diagonal_solve
  527. LDLsolve
  528. LUsolve
  529. QRsolve
  530. pinv
  531. Notes
  532. =====
  533. This may return either exact solutions or least squares solutions.
  534. To determine which, check ``A * A.pinv() * B == B``. It will be
  535. True if exact solutions exist, and False if only a least-squares
  536. solution exists. Be aware that the left hand side of that equation
  537. may need to be simplified to correctly compare to the right hand
  538. side.
  539. References
  540. ==========
  541. .. [1] https://en.wikipedia.org/wiki/Moore-Penrose_pseudoinverse#Obtaining_all_solutions_of_a_linear_system
  542. """
  543. from sympy.matrices import eye
  544. A = M
  545. A_pinv = M.pinv()
  546. if arbitrary_matrix is None:
  547. rows, cols = A.cols, B.cols
  548. w = symbols('w:{}_:{}'.format(rows, cols), cls=Dummy)
  549. arbitrary_matrix = M.__class__(cols, rows, w).T
  550. return A_pinv.multiply(B) + (eye(A.cols) -
  551. A_pinv.multiply(A)).multiply(arbitrary_matrix)
  552. def _cramer_solve(M, rhs, det_method="laplace"):
  553. """Solves system of linear equations using Cramer's rule.
  554. This method is relatively inefficient compared to other methods.
  555. However it only uses a single division, assuming a division-free determinant
  556. method is provided. This is helpful to minimize the chance of divide-by-zero
  557. cases in symbolic solutions to linear systems.
  558. Parameters
  559. ==========
  560. M : Matrix
  561. The matrix representing the left hand side of the equation.
  562. rhs : Matrix
  563. The matrix representing the right hand side of the equation.
  564. det_method : str or callable
  565. The method to use to calculate the determinant of the matrix.
  566. The default is ``'laplace'``. If a callable is passed, it should take a
  567. single argument, the matrix, and return the determinant of the matrix.
  568. Returns
  569. =======
  570. x : Matrix
  571. The matrix that will satisfy ``Ax = B``. Will have as many rows as
  572. matrix A has columns, and as many columns as matrix B.
  573. Examples
  574. ========
  575. >>> from sympy import Matrix
  576. >>> A = Matrix([[0, -6, 1], [0, -6, -1], [-5, -2, 3]])
  577. >>> B = Matrix([[-30, -9], [-18, -27], [-26, 46]])
  578. >>> x = A.cramer_solve(B)
  579. >>> x
  580. Matrix([
  581. [ 0, -5],
  582. [ 4, 3],
  583. [-6, 9]])
  584. References
  585. ==========
  586. .. [1] https://en.wikipedia.org/wiki/Cramer%27s_rule#Explicit_formulas_for_small_systems
  587. """
  588. from .dense import zeros
  589. def entry(i, j):
  590. return rhs[i, sol] if j == col else M[i, j]
  591. if det_method == "bird":
  592. from .determinant import _det_bird
  593. det = _det_bird
  594. elif det_method == "laplace":
  595. from .determinant import _det_laplace
  596. det = _det_laplace
  597. elif isinstance(det_method, str):
  598. det = lambda matrix: matrix.det(method=det_method)
  599. else:
  600. det = det_method
  601. det_M = det(M)
  602. x = zeros(*rhs.shape)
  603. for sol in range(rhs.shape[1]):
  604. for col in range(rhs.shape[0]):
  605. x[col, sol] = det(M.__class__(*M.shape, entry)) / det_M
  606. return M.__class__(x)
  607. def _solve(M, rhs, method='GJ'):
  608. """Solves linear equation where the unique solution exists.
  609. Parameters
  610. ==========
  611. rhs : Matrix
  612. Vector representing the right hand side of the linear equation.
  613. method : string, optional
  614. If set to ``'GJ'`` or ``'GE'``, the Gauss-Jordan elimination will be
  615. used, which is implemented in the routine ``gauss_jordan_solve``.
  616. If set to ``'LU'``, ``LUsolve`` routine will be used.
  617. If set to ``'QR'``, ``QRsolve`` routine will be used.
  618. If set to ``'PINV'``, ``pinv_solve`` routine will be used.
  619. If set to ``'CRAMER'``, ``cramer_solve`` routine will be used.
  620. It also supports the methods available for special linear systems
  621. For positive definite systems:
  622. If set to ``'CH'``, ``cholesky_solve`` routine will be used.
  623. If set to ``'LDL'``, ``LDLsolve`` routine will be used.
  624. To use a different method and to compute the solution via the
  625. inverse, use a method defined in the .inv() docstring.
  626. Returns
  627. =======
  628. solutions : Matrix
  629. Vector representing the solution.
  630. Raises
  631. ======
  632. ValueError
  633. If there is not a unique solution then a ``ValueError`` will be
  634. raised.
  635. If ``M`` is not square, a ``ValueError`` and a different routine
  636. for solving the system will be suggested.
  637. """
  638. if method in ('GJ', 'GE'):
  639. try:
  640. soln, param = M.gauss_jordan_solve(rhs)
  641. if param:
  642. raise NonInvertibleMatrixError("Matrix det == 0; not invertible. "
  643. "Try ``M.gauss_jordan_solve(rhs)`` to obtain a parametric solution.")
  644. except ValueError:
  645. raise NonInvertibleMatrixError("Matrix det == 0; not invertible.")
  646. return soln
  647. elif method == 'LU':
  648. return M.LUsolve(rhs)
  649. elif method == 'CH':
  650. return M.cholesky_solve(rhs)
  651. elif method == 'QR':
  652. return M.QRsolve(rhs)
  653. elif method == 'LDL':
  654. return M.LDLsolve(rhs)
  655. elif method == 'PINV':
  656. return M.pinv_solve(rhs)
  657. elif method == 'CRAMER':
  658. return M.cramer_solve(rhs)
  659. else:
  660. return M.inv(method=method).multiply(rhs)
  661. def _solve_least_squares(M, rhs, method='CH'):
  662. """Return the least-square fit to the data.
  663. Parameters
  664. ==========
  665. rhs : Matrix
  666. Vector representing the right hand side of the linear equation.
  667. method : string or boolean, optional
  668. If set to ``'CH'``, ``cholesky_solve`` routine will be used.
  669. If set to ``'LDL'``, ``LDLsolve`` routine will be used.
  670. If set to ``'QR'``, ``QRsolve`` routine will be used.
  671. If set to ``'PINV'``, ``pinv_solve`` routine will be used.
  672. Otherwise, the conjugate of ``M`` will be used to create a system
  673. of equations that is passed to ``solve`` along with the hint
  674. defined by ``method``.
  675. Returns
  676. =======
  677. solutions : Matrix
  678. Vector representing the solution.
  679. Examples
  680. ========
  681. >>> from sympy import Matrix, ones
  682. >>> A = Matrix([1, 2, 3])
  683. >>> B = Matrix([2, 3, 4])
  684. >>> S = Matrix(A.row_join(B))
  685. >>> S
  686. Matrix([
  687. [1, 2],
  688. [2, 3],
  689. [3, 4]])
  690. If each line of S represent coefficients of Ax + By
  691. and x and y are [2, 3] then S*xy is:
  692. >>> r = S*Matrix([2, 3]); r
  693. Matrix([
  694. [ 8],
  695. [13],
  696. [18]])
  697. But let's add 1 to the middle value and then solve for the
  698. least-squares value of xy:
  699. >>> xy = S.solve_least_squares(Matrix([8, 14, 18])); xy
  700. Matrix([
  701. [ 5/3],
  702. [10/3]])
  703. The error is given by S*xy - r:
  704. >>> S*xy - r
  705. Matrix([
  706. [1/3],
  707. [1/3],
  708. [1/3]])
  709. >>> _.norm().n(2)
  710. 0.58
  711. If a different xy is used, the norm will be higher:
  712. >>> xy += ones(2, 1)/10
  713. >>> (S*xy - r).norm().n(2)
  714. 1.5
  715. """
  716. if method == 'CH':
  717. return M.cholesky_solve(rhs)
  718. elif method == 'QR':
  719. return M.QRsolve(rhs)
  720. elif method == 'LDL':
  721. return M.LDLsolve(rhs)
  722. elif method == 'PINV':
  723. return M.pinv_solve(rhs)
  724. else:
  725. t = M.H
  726. return (t * M).solve(t * rhs, method=method)