blockmatrix.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975
  1. from sympy.assumptions.ask import (Q, ask)
  2. from sympy.core import Basic, Add, Mul, S
  3. from sympy.core.sympify import _sympify
  4. from sympy.functions.elementary.complexes import re, im
  5. from sympy.strategies import typed, exhaust, condition, do_one, unpack
  6. from sympy.strategies.traverse import bottom_up
  7. from sympy.utilities.iterables import is_sequence, sift
  8. from sympy.utilities.misc import filldedent
  9. from sympy.matrices import Matrix, ShapeError
  10. from sympy.matrices.exceptions import NonInvertibleMatrixError
  11. from sympy.matrices.expressions.determinant import det, Determinant
  12. from sympy.matrices.expressions.inverse import Inverse
  13. from sympy.matrices.expressions.matadd import MatAdd
  14. from sympy.matrices.expressions.matexpr import MatrixExpr, MatrixElement
  15. from sympy.matrices.expressions.matmul import MatMul
  16. from sympy.matrices.expressions.matpow import MatPow
  17. from sympy.matrices.expressions.slice import MatrixSlice
  18. from sympy.matrices.expressions.special import ZeroMatrix, Identity
  19. from sympy.matrices.expressions.trace import trace
  20. from sympy.matrices.expressions.transpose import Transpose, transpose
  21. class BlockMatrix(MatrixExpr):
  22. """A BlockMatrix is a Matrix comprised of other matrices.
  23. The submatrices are stored in a SymPy Matrix object but accessed as part of
  24. a Matrix Expression
  25. >>> from sympy import (MatrixSymbol, BlockMatrix, symbols,
  26. ... Identity, ZeroMatrix, block_collapse)
  27. >>> n,m,l = symbols('n m l')
  28. >>> X = MatrixSymbol('X', n, n)
  29. >>> Y = MatrixSymbol('Y', m, m)
  30. >>> Z = MatrixSymbol('Z', n, m)
  31. >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m,n), Y]])
  32. >>> print(B)
  33. Matrix([
  34. [X, Z],
  35. [0, Y]])
  36. >>> C = BlockMatrix([[Identity(n), Z]])
  37. >>> print(C)
  38. Matrix([[I, Z]])
  39. >>> print(block_collapse(C*B))
  40. Matrix([[X, Z + Z*Y]])
  41. Some matrices might be comprised of rows of blocks with
  42. the matrices in each row having the same height and the
  43. rows all having the same total number of columns but
  44. not having the same number of columns for each matrix
  45. in each row. In this case, the matrix is not a block
  46. matrix and should be instantiated by Matrix.
  47. >>> from sympy import ones, Matrix
  48. >>> dat = [
  49. ... [ones(3,2), ones(3,3)*2],
  50. ... [ones(2,3)*3, ones(2,2)*4]]
  51. ...
  52. >>> BlockMatrix(dat)
  53. Traceback (most recent call last):
  54. ...
  55. ValueError:
  56. Although this matrix is comprised of blocks, the blocks do not fill
  57. the matrix in a size-symmetric fashion. To create a full matrix from
  58. these arguments, pass them directly to Matrix.
  59. >>> Matrix(dat)
  60. Matrix([
  61. [1, 1, 2, 2, 2],
  62. [1, 1, 2, 2, 2],
  63. [1, 1, 2, 2, 2],
  64. [3, 3, 3, 4, 4],
  65. [3, 3, 3, 4, 4]])
  66. See Also
  67. ========
  68. sympy.matrices.matrixbase.MatrixBase.irregular
  69. """
  70. def __new__(cls, *args, **kwargs):
  71. from sympy.matrices.immutable import ImmutableDenseMatrix
  72. isMat = lambda i: getattr(i, 'is_Matrix', False)
  73. if len(args) != 1 or \
  74. not is_sequence(args[0]) or \
  75. len({isMat(r) for r in args[0]}) != 1:
  76. raise ValueError(filldedent('''
  77. expecting a sequence of 1 or more rows
  78. containing Matrices.'''))
  79. rows = args[0] if args else []
  80. if not isMat(rows):
  81. if rows and isMat(rows[0]):
  82. rows = [rows] # rows is not list of lists or []
  83. # regularity check
  84. # same number of matrices in each row
  85. blocky = ok = len({len(r) for r in rows}) == 1
  86. if ok:
  87. # same number of rows for each matrix in a row
  88. for r in rows:
  89. ok = len({i.rows for i in r}) == 1
  90. if not ok:
  91. break
  92. blocky = ok
  93. if ok:
  94. # same number of cols for each matrix in each col
  95. for c in range(len(rows[0])):
  96. ok = len({rows[i][c].cols
  97. for i in range(len(rows))}) == 1
  98. if not ok:
  99. break
  100. if not ok:
  101. # same total cols in each row
  102. ok = len({
  103. sum(i.cols for i in r) for r in rows}) == 1
  104. if blocky and ok:
  105. raise ValueError(filldedent('''
  106. Although this matrix is comprised of blocks,
  107. the blocks do not fill the matrix in a
  108. size-symmetric fashion. To create a full matrix
  109. from these arguments, pass them directly to
  110. Matrix.'''))
  111. raise ValueError(filldedent('''
  112. When there are not the same number of rows in each
  113. row's matrices or there are not the same number of
  114. total columns in each row, the matrix is not a
  115. block matrix. If this matrix is known to consist of
  116. blocks fully filling a 2-D space then see
  117. Matrix.irregular.'''))
  118. mat = ImmutableDenseMatrix(rows, evaluate=False)
  119. obj = Basic.__new__(cls, mat)
  120. return obj
  121. @property
  122. def shape(self):
  123. numrows = numcols = 0
  124. M = self.blocks
  125. for i in range(M.shape[0]):
  126. numrows += M[i, 0].shape[0]
  127. for i in range(M.shape[1]):
  128. numcols += M[0, i].shape[1]
  129. return (numrows, numcols)
  130. @property
  131. def blockshape(self):
  132. return self.blocks.shape
  133. @property
  134. def blocks(self):
  135. return self.args[0]
  136. @property
  137. def rowblocksizes(self):
  138. return [self.blocks[i, 0].rows for i in range(self.blockshape[0])]
  139. @property
  140. def colblocksizes(self):
  141. return [self.blocks[0, i].cols for i in range(self.blockshape[1])]
  142. def structurally_equal(self, other):
  143. return (isinstance(other, BlockMatrix)
  144. and self.shape == other.shape
  145. and self.blockshape == other.blockshape
  146. and self.rowblocksizes == other.rowblocksizes
  147. and self.colblocksizes == other.colblocksizes)
  148. def _blockmul(self, other):
  149. if (isinstance(other, BlockMatrix) and
  150. self.colblocksizes == other.rowblocksizes):
  151. return BlockMatrix(self.blocks*other.blocks)
  152. return self * other
  153. def _blockadd(self, other):
  154. if (isinstance(other, BlockMatrix)
  155. and self.structurally_equal(other)):
  156. return BlockMatrix(self.blocks + other.blocks)
  157. return self + other
  158. def _eval_transpose(self):
  159. # Flip all the individual matrices
  160. matrices = [transpose(matrix) for matrix in self.blocks]
  161. # Make a copy
  162. M = Matrix(self.blockshape[0], self.blockshape[1], matrices)
  163. # Transpose the block structure
  164. M = M.transpose()
  165. return BlockMatrix(M)
  166. def _eval_adjoint(self):
  167. return BlockMatrix(
  168. Matrix(self.blockshape[0], self.blockshape[1], self.blocks).adjoint()
  169. )
  170. def _eval_trace(self):
  171. if self.rowblocksizes == self.colblocksizes:
  172. blocks = [self.blocks[i, i] for i in range(self.blockshape[0])]
  173. return Add(*[trace(block) for block in blocks])
  174. def _eval_determinant(self):
  175. if self.blockshape == (1, 1):
  176. return det(self.blocks[0, 0])
  177. if self.blockshape == (2, 2):
  178. [[A, B],
  179. [C, D]] = self.blocks.tolist()
  180. if ask(Q.invertible(A)):
  181. return det(A)*det(D - C*A.I*B)
  182. elif ask(Q.invertible(D)):
  183. return det(D)*det(A - B*D.I*C)
  184. return Determinant(self)
  185. def _eval_as_real_imag(self):
  186. real_matrices = [re(matrix) for matrix in self.blocks]
  187. real_matrices = Matrix(self.blockshape[0], self.blockshape[1], real_matrices)
  188. im_matrices = [im(matrix) for matrix in self.blocks]
  189. im_matrices = Matrix(self.blockshape[0], self.blockshape[1], im_matrices)
  190. return (BlockMatrix(real_matrices), BlockMatrix(im_matrices))
  191. def _eval_derivative(self, x):
  192. return BlockMatrix(self.blocks.diff(x))
  193. def transpose(self):
  194. """Return transpose of matrix.
  195. Examples
  196. ========
  197. >>> from sympy import MatrixSymbol, BlockMatrix, ZeroMatrix
  198. >>> from sympy.abc import m, n
  199. >>> X = MatrixSymbol('X', n, n)
  200. >>> Y = MatrixSymbol('Y', m, m)
  201. >>> Z = MatrixSymbol('Z', n, m)
  202. >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m,n), Y]])
  203. >>> B.transpose()
  204. Matrix([
  205. [X.T, 0],
  206. [Z.T, Y.T]])
  207. >>> _.transpose()
  208. Matrix([
  209. [X, Z],
  210. [0, Y]])
  211. """
  212. return self._eval_transpose()
  213. def schur(self, mat = 'A', generalized = False):
  214. """Return the Schur Complement of the 2x2 BlockMatrix
  215. Parameters
  216. ==========
  217. mat : String, optional
  218. The matrix with respect to which the
  219. Schur Complement is calculated. 'A' is
  220. used by default
  221. generalized : bool, optional
  222. If True, returns the generalized Schur
  223. Component which uses Moore-Penrose Inverse
  224. Examples
  225. ========
  226. >>> from sympy import symbols, MatrixSymbol, BlockMatrix
  227. >>> m, n = symbols('m n')
  228. >>> A = MatrixSymbol('A', n, n)
  229. >>> B = MatrixSymbol('B', n, m)
  230. >>> C = MatrixSymbol('C', m, n)
  231. >>> D = MatrixSymbol('D', m, m)
  232. >>> X = BlockMatrix([[A, B], [C, D]])
  233. The default Schur Complement is evaluated with "A"
  234. >>> X.schur()
  235. -C*A**(-1)*B + D
  236. >>> X.schur('D')
  237. A - B*D**(-1)*C
  238. Schur complement with non-invertible matrices is not
  239. defined. Instead, the generalized Schur complement can
  240. be calculated which uses the Moore-Penrose Inverse. To
  241. achieve this, `generalized` must be set to `True`
  242. >>> X.schur('B', generalized=True)
  243. C - D*(B.T*B)**(-1)*B.T*A
  244. >>> X.schur('C', generalized=True)
  245. -A*(C.T*C)**(-1)*C.T*D + B
  246. Returns
  247. =======
  248. M : Matrix
  249. The Schur Complement Matrix
  250. Raises
  251. ======
  252. ShapeError
  253. If the block matrix is not a 2x2 matrix
  254. NonInvertibleMatrixError
  255. If given matrix is non-invertible
  256. References
  257. ==========
  258. .. [1] Wikipedia Article on Schur Component : https://en.wikipedia.org/wiki/Schur_complement
  259. See Also
  260. ========
  261. sympy.matrices.matrixbase.MatrixBase.pinv
  262. """
  263. if self.blockshape == (2, 2):
  264. [[A, B],
  265. [C, D]] = self.blocks.tolist()
  266. d={'A' : A, 'B' : B, 'C' : C, 'D' : D}
  267. try:
  268. inv = (d[mat].T*d[mat]).inv()*d[mat].T if generalized else d[mat].inv()
  269. if mat == 'A':
  270. return D - C * inv * B
  271. elif mat == 'B':
  272. return C - D * inv * A
  273. elif mat == 'C':
  274. return B - A * inv * D
  275. elif mat == 'D':
  276. return A - B * inv * C
  277. #For matrices where no sub-matrix is square
  278. return self
  279. except NonInvertibleMatrixError:
  280. raise NonInvertibleMatrixError('The given matrix is not invertible. Please set generalized=True \
  281. to compute the generalized Schur Complement which uses Moore-Penrose Inverse')
  282. else:
  283. raise ShapeError('Schur Complement can only be calculated for 2x2 block matrices')
  284. def LDUdecomposition(self):
  285. """Returns the Block LDU decomposition of
  286. a 2x2 Block Matrix
  287. Returns
  288. =======
  289. (L, D, U) : Matrices
  290. L : Lower Diagonal Matrix
  291. D : Diagonal Matrix
  292. U : Upper Diagonal Matrix
  293. Examples
  294. ========
  295. >>> from sympy import symbols, MatrixSymbol, BlockMatrix, block_collapse
  296. >>> m, n = symbols('m n')
  297. >>> A = MatrixSymbol('A', n, n)
  298. >>> B = MatrixSymbol('B', n, m)
  299. >>> C = MatrixSymbol('C', m, n)
  300. >>> D = MatrixSymbol('D', m, m)
  301. >>> X = BlockMatrix([[A, B], [C, D]])
  302. >>> L, D, U = X.LDUdecomposition()
  303. >>> block_collapse(L*D*U)
  304. Matrix([
  305. [A, B],
  306. [C, D]])
  307. Raises
  308. ======
  309. ShapeError
  310. If the block matrix is not a 2x2 matrix
  311. NonInvertibleMatrixError
  312. If the matrix "A" is non-invertible
  313. See Also
  314. ========
  315. sympy.matrices.expressions.blockmatrix.BlockMatrix.UDLdecomposition
  316. sympy.matrices.expressions.blockmatrix.BlockMatrix.LUdecomposition
  317. """
  318. if self.blockshape == (2,2):
  319. [[A, B],
  320. [C, D]] = self.blocks.tolist()
  321. try:
  322. AI = A.I
  323. except NonInvertibleMatrixError:
  324. raise NonInvertibleMatrixError('Block LDU decomposition cannot be calculated when\
  325. "A" is singular')
  326. Ip = Identity(B.shape[0])
  327. Iq = Identity(B.shape[1])
  328. Z = ZeroMatrix(*B.shape)
  329. L = BlockMatrix([[Ip, Z], [C*AI, Iq]])
  330. D = BlockDiagMatrix(A, self.schur())
  331. U = BlockMatrix([[Ip, AI*B],[Z.T, Iq]])
  332. return L, D, U
  333. else:
  334. raise ShapeError("Block LDU decomposition is supported only for 2x2 block matrices")
  335. def UDLdecomposition(self):
  336. """Returns the Block UDL decomposition of
  337. a 2x2 Block Matrix
  338. Returns
  339. =======
  340. (U, D, L) : Matrices
  341. U : Upper Diagonal Matrix
  342. D : Diagonal Matrix
  343. L : Lower Diagonal Matrix
  344. Examples
  345. ========
  346. >>> from sympy import symbols, MatrixSymbol, BlockMatrix, block_collapse
  347. >>> m, n = symbols('m n')
  348. >>> A = MatrixSymbol('A', n, n)
  349. >>> B = MatrixSymbol('B', n, m)
  350. >>> C = MatrixSymbol('C', m, n)
  351. >>> D = MatrixSymbol('D', m, m)
  352. >>> X = BlockMatrix([[A, B], [C, D]])
  353. >>> U, D, L = X.UDLdecomposition()
  354. >>> block_collapse(U*D*L)
  355. Matrix([
  356. [A, B],
  357. [C, D]])
  358. Raises
  359. ======
  360. ShapeError
  361. If the block matrix is not a 2x2 matrix
  362. NonInvertibleMatrixError
  363. If the matrix "D" is non-invertible
  364. See Also
  365. ========
  366. sympy.matrices.expressions.blockmatrix.BlockMatrix.LDUdecomposition
  367. sympy.matrices.expressions.blockmatrix.BlockMatrix.LUdecomposition
  368. """
  369. if self.blockshape == (2,2):
  370. [[A, B],
  371. [C, D]] = self.blocks.tolist()
  372. try:
  373. DI = D.I
  374. except NonInvertibleMatrixError:
  375. raise NonInvertibleMatrixError('Block UDL decomposition cannot be calculated when\
  376. "D" is singular')
  377. Ip = Identity(A.shape[0])
  378. Iq = Identity(B.shape[1])
  379. Z = ZeroMatrix(*B.shape)
  380. U = BlockMatrix([[Ip, B*DI], [Z.T, Iq]])
  381. D = BlockDiagMatrix(self.schur('D'), D)
  382. L = BlockMatrix([[Ip, Z],[DI*C, Iq]])
  383. return U, D, L
  384. else:
  385. raise ShapeError("Block UDL decomposition is supported only for 2x2 block matrices")
  386. def LUdecomposition(self):
  387. """Returns the Block LU decomposition of
  388. a 2x2 Block Matrix
  389. Returns
  390. =======
  391. (L, U) : Matrices
  392. L : Lower Diagonal Matrix
  393. U : Upper Diagonal Matrix
  394. Examples
  395. ========
  396. >>> from sympy import symbols, MatrixSymbol, BlockMatrix, block_collapse
  397. >>> m, n = symbols('m n')
  398. >>> A = MatrixSymbol('A', n, n)
  399. >>> B = MatrixSymbol('B', n, m)
  400. >>> C = MatrixSymbol('C', m, n)
  401. >>> D = MatrixSymbol('D', m, m)
  402. >>> X = BlockMatrix([[A, B], [C, D]])
  403. >>> L, U = X.LUdecomposition()
  404. >>> block_collapse(L*U)
  405. Matrix([
  406. [A, B],
  407. [C, D]])
  408. Raises
  409. ======
  410. ShapeError
  411. If the block matrix is not a 2x2 matrix
  412. NonInvertibleMatrixError
  413. If the matrix "A" is non-invertible
  414. See Also
  415. ========
  416. sympy.matrices.expressions.blockmatrix.BlockMatrix.UDLdecomposition
  417. sympy.matrices.expressions.blockmatrix.BlockMatrix.LDUdecomposition
  418. """
  419. if self.blockshape == (2,2):
  420. [[A, B],
  421. [C, D]] = self.blocks.tolist()
  422. try:
  423. A = A**S.Half
  424. AI = A.I
  425. except NonInvertibleMatrixError:
  426. raise NonInvertibleMatrixError('Block LU decomposition cannot be calculated when\
  427. "A" is singular')
  428. Z = ZeroMatrix(*B.shape)
  429. Q = self.schur()**S.Half
  430. L = BlockMatrix([[A, Z], [C*AI, Q]])
  431. U = BlockMatrix([[A, AI*B],[Z.T, Q]])
  432. return L, U
  433. else:
  434. raise ShapeError("Block LU decomposition is supported only for 2x2 block matrices")
  435. def _entry(self, i, j, **kwargs):
  436. # Find row entry
  437. orig_i, orig_j = i, j
  438. for row_block, numrows in enumerate(self.rowblocksizes):
  439. cmp = i < numrows
  440. if cmp == True:
  441. break
  442. elif cmp == False:
  443. i -= numrows
  444. elif row_block < self.blockshape[0] - 1:
  445. # Can't tell which block and it's not the last one, return unevaluated
  446. return MatrixElement(self, orig_i, orig_j)
  447. for col_block, numcols in enumerate(self.colblocksizes):
  448. cmp = j < numcols
  449. if cmp == True:
  450. break
  451. elif cmp == False:
  452. j -= numcols
  453. elif col_block < self.blockshape[1] - 1:
  454. return MatrixElement(self, orig_i, orig_j)
  455. return self.blocks[row_block, col_block][i, j]
  456. @property
  457. def is_Identity(self):
  458. if self.blockshape[0] != self.blockshape[1]:
  459. return False
  460. for i in range(self.blockshape[0]):
  461. for j in range(self.blockshape[1]):
  462. if i==j and not self.blocks[i, j].is_Identity:
  463. return False
  464. if i!=j and not self.blocks[i, j].is_ZeroMatrix:
  465. return False
  466. return True
  467. @property
  468. def is_structurally_symmetric(self):
  469. return self.rowblocksizes == self.colblocksizes
  470. def equals(self, other):
  471. if self == other:
  472. return True
  473. if (isinstance(other, BlockMatrix) and self.blocks == other.blocks):
  474. return True
  475. return super().equals(other)
  476. class BlockDiagMatrix(BlockMatrix):
  477. """A sparse matrix with block matrices along its diagonals
  478. Examples
  479. ========
  480. >>> from sympy import MatrixSymbol, BlockDiagMatrix, symbols
  481. >>> n, m, l = symbols('n m l')
  482. >>> X = MatrixSymbol('X', n, n)
  483. >>> Y = MatrixSymbol('Y', m, m)
  484. >>> BlockDiagMatrix(X, Y)
  485. Matrix([
  486. [X, 0],
  487. [0, Y]])
  488. Notes
  489. =====
  490. If you want to get the individual diagonal blocks, use
  491. :meth:`get_diag_blocks`.
  492. See Also
  493. ========
  494. sympy.matrices.dense.diag
  495. """
  496. def __new__(cls, *mats):
  497. return Basic.__new__(BlockDiagMatrix, *[_sympify(m) for m in mats])
  498. @property
  499. def diag(self):
  500. return self.args
  501. @property
  502. def blocks(self):
  503. from sympy.matrices.immutable import ImmutableDenseMatrix
  504. mats = self.args
  505. data = [[mats[i] if i == j else ZeroMatrix(mats[i].rows, mats[j].cols)
  506. for j in range(len(mats))]
  507. for i in range(len(mats))]
  508. return ImmutableDenseMatrix(data, evaluate=False)
  509. @property
  510. def shape(self):
  511. return (sum(block.rows for block in self.args),
  512. sum(block.cols for block in self.args))
  513. @property
  514. def blockshape(self):
  515. n = len(self.args)
  516. return (n, n)
  517. @property
  518. def rowblocksizes(self):
  519. return [block.rows for block in self.args]
  520. @property
  521. def colblocksizes(self):
  522. return [block.cols for block in self.args]
  523. def _all_square_blocks(self):
  524. """Returns true if all blocks are square"""
  525. return all(mat.is_square for mat in self.args)
  526. def _eval_determinant(self):
  527. if self._all_square_blocks():
  528. return Mul(*[det(mat) for mat in self.args])
  529. # At least one block is non-square. Since the entire matrix must be square we know there must
  530. # be at least two blocks in this matrix, in which case the entire matrix is necessarily rank-deficient
  531. return S.Zero
  532. def _eval_inverse(self, expand='ignored'):
  533. if self._all_square_blocks():
  534. return BlockDiagMatrix(*[mat.inverse() for mat in self.args])
  535. # See comment in _eval_determinant()
  536. raise NonInvertibleMatrixError('Matrix det == 0; not invertible.')
  537. def _eval_transpose(self):
  538. return BlockDiagMatrix(*[mat.transpose() for mat in self.args])
  539. def _blockmul(self, other):
  540. if (isinstance(other, BlockDiagMatrix) and
  541. self.colblocksizes == other.rowblocksizes):
  542. return BlockDiagMatrix(*[a*b for a, b in zip(self.args, other.args)])
  543. else:
  544. return BlockMatrix._blockmul(self, other)
  545. def _blockadd(self, other):
  546. if (isinstance(other, BlockDiagMatrix) and
  547. self.blockshape == other.blockshape and
  548. self.rowblocksizes == other.rowblocksizes and
  549. self.colblocksizes == other.colblocksizes):
  550. return BlockDiagMatrix(*[a + b for a, b in zip(self.args, other.args)])
  551. else:
  552. return BlockMatrix._blockadd(self, other)
  553. def get_diag_blocks(self):
  554. """Return the list of diagonal blocks of the matrix.
  555. Examples
  556. ========
  557. >>> from sympy import BlockDiagMatrix, Matrix
  558. >>> A = Matrix([[1, 2], [3, 4]])
  559. >>> B = Matrix([[5, 6], [7, 8]])
  560. >>> M = BlockDiagMatrix(A, B)
  561. How to get diagonal blocks from the block diagonal matrix:
  562. >>> diag_blocks = M.get_diag_blocks()
  563. >>> diag_blocks[0]
  564. Matrix([
  565. [1, 2],
  566. [3, 4]])
  567. >>> diag_blocks[1]
  568. Matrix([
  569. [5, 6],
  570. [7, 8]])
  571. """
  572. return self.args
  573. def block_collapse(expr):
  574. """Evaluates a block matrix expression
  575. >>> from sympy import MatrixSymbol, BlockMatrix, symbols, Identity, ZeroMatrix, block_collapse
  576. >>> n,m,l = symbols('n m l')
  577. >>> X = MatrixSymbol('X', n, n)
  578. >>> Y = MatrixSymbol('Y', m, m)
  579. >>> Z = MatrixSymbol('Z', n, m)
  580. >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m, n), Y]])
  581. >>> print(B)
  582. Matrix([
  583. [X, Z],
  584. [0, Y]])
  585. >>> C = BlockMatrix([[Identity(n), Z]])
  586. >>> print(C)
  587. Matrix([[I, Z]])
  588. >>> print(block_collapse(C*B))
  589. Matrix([[X, Z + Z*Y]])
  590. """
  591. from sympy.strategies.util import expr_fns
  592. hasbm = lambda expr: isinstance(expr, MatrixExpr) and expr.has(BlockMatrix)
  593. conditioned_rl = condition(
  594. hasbm,
  595. typed(
  596. {MatAdd: do_one(bc_matadd, bc_block_plus_ident),
  597. MatMul: do_one(bc_matmul, bc_dist),
  598. MatPow: bc_matmul,
  599. Transpose: bc_transpose,
  600. Inverse: bc_inverse,
  601. BlockMatrix: do_one(bc_unpack, deblock)}
  602. )
  603. )
  604. rule = exhaust(
  605. bottom_up(
  606. exhaust(conditioned_rl),
  607. fns=expr_fns
  608. )
  609. )
  610. result = rule(expr)
  611. doit = getattr(result, 'doit', None)
  612. if doit is not None:
  613. return doit()
  614. else:
  615. return result
  616. def bc_unpack(expr):
  617. if expr.blockshape == (1, 1):
  618. return expr.blocks[0, 0]
  619. return expr
  620. def bc_matadd(expr):
  621. args = sift(expr.args, lambda M: isinstance(M, BlockMatrix))
  622. blocks = args[True]
  623. if not blocks:
  624. return expr
  625. nonblocks = args[False]
  626. block = blocks[0]
  627. for b in blocks[1:]:
  628. block = block._blockadd(b)
  629. if nonblocks:
  630. return MatAdd(*nonblocks) + block
  631. else:
  632. return block
  633. def bc_block_plus_ident(expr):
  634. idents = [arg for arg in expr.args if arg.is_Identity]
  635. if not idents:
  636. return expr
  637. blocks = [arg for arg in expr.args if isinstance(arg, BlockMatrix)]
  638. if (blocks and all(b.structurally_equal(blocks[0]) for b in blocks)
  639. and blocks[0].is_structurally_symmetric):
  640. block_id = BlockDiagMatrix(*[Identity(k)
  641. for k in blocks[0].rowblocksizes])
  642. rest = [arg for arg in expr.args if not arg.is_Identity and not isinstance(arg, BlockMatrix)]
  643. return MatAdd(block_id * len(idents), *blocks, *rest).doit()
  644. return expr
  645. def bc_dist(expr):
  646. """ Turn a*[X, Y] into [a*X, a*Y] """
  647. factor, mat = expr.as_coeff_mmul()
  648. if factor == 1:
  649. return expr
  650. unpacked = unpack(mat)
  651. if isinstance(unpacked, BlockDiagMatrix):
  652. B = unpacked.diag
  653. new_B = [factor * mat for mat in B]
  654. return BlockDiagMatrix(*new_B)
  655. elif isinstance(unpacked, BlockMatrix):
  656. B = unpacked.blocks
  657. new_B = [
  658. [factor * B[i, j] for j in range(B.cols)] for i in range(B.rows)]
  659. return BlockMatrix(new_B)
  660. return expr
  661. def bc_matmul(expr):
  662. if isinstance(expr, MatPow):
  663. if expr.args[1].is_Integer and expr.args[1] > 0:
  664. factor, matrices = 1, [expr.args[0]]*expr.args[1]
  665. else:
  666. return expr
  667. else:
  668. factor, matrices = expr.as_coeff_matrices()
  669. i = 0
  670. while (i+1 < len(matrices)):
  671. A, B = matrices[i:i+2]
  672. if isinstance(A, BlockMatrix) and isinstance(B, BlockMatrix):
  673. matrices[i] = A._blockmul(B)
  674. matrices.pop(i+1)
  675. elif isinstance(A, BlockMatrix):
  676. matrices[i] = A._blockmul(BlockMatrix([[B]]))
  677. matrices.pop(i+1)
  678. elif isinstance(B, BlockMatrix):
  679. matrices[i] = BlockMatrix([[A]])._blockmul(B)
  680. matrices.pop(i+1)
  681. else:
  682. i+=1
  683. return MatMul(factor, *matrices).doit()
  684. def bc_transpose(expr):
  685. collapse = block_collapse(expr.arg)
  686. return collapse._eval_transpose()
  687. def bc_inverse(expr):
  688. if isinstance(expr.arg, BlockDiagMatrix):
  689. return expr.inverse()
  690. expr2 = blockinverse_1x1(expr)
  691. if expr != expr2:
  692. return expr2
  693. return blockinverse_2x2(Inverse(reblock_2x2(expr.arg)))
  694. def blockinverse_1x1(expr):
  695. if isinstance(expr.arg, BlockMatrix) and expr.arg.blockshape == (1, 1):
  696. mat = Matrix([[expr.arg.blocks[0].inverse()]])
  697. return BlockMatrix(mat)
  698. return expr
  699. def blockinverse_2x2(expr):
  700. if isinstance(expr.arg, BlockMatrix) and expr.arg.blockshape == (2, 2):
  701. # See: Inverses of 2x2 Block Matrices, Tzon-Tzer Lu and Sheng-Hua Shiou
  702. [[A, B],
  703. [C, D]] = expr.arg.blocks.tolist()
  704. formula = _choose_2x2_inversion_formula(A, B, C, D)
  705. if formula != None:
  706. MI = expr.arg.schur(formula).I
  707. if formula == 'A':
  708. AI = A.I
  709. return BlockMatrix([[AI + AI * B * MI * C * AI, -AI * B * MI], [-MI * C * AI, MI]])
  710. if formula == 'B':
  711. BI = B.I
  712. return BlockMatrix([[-MI * D * BI, MI], [BI + BI * A * MI * D * BI, -BI * A * MI]])
  713. if formula == 'C':
  714. CI = C.I
  715. return BlockMatrix([[-CI * D * MI, CI + CI * D * MI * A * CI], [MI, -MI * A * CI]])
  716. if formula == 'D':
  717. DI = D.I
  718. return BlockMatrix([[MI, -MI * B * DI], [-DI * C * MI, DI + DI * C * MI * B * DI]])
  719. return expr
  720. def _choose_2x2_inversion_formula(A, B, C, D):
  721. """
  722. Assuming [[A, B], [C, D]] would form a valid square block matrix, find
  723. which of the classical 2x2 block matrix inversion formulas would be
  724. best suited.
  725. Returns 'A', 'B', 'C', 'D' to represent the algorithm involving inversion
  726. of the given argument or None if the matrix cannot be inverted using
  727. any of those formulas.
  728. """
  729. # Try to find a known invertible matrix. Note that the Schur complement
  730. # is currently not being considered for this
  731. A_inv = ask(Q.invertible(A))
  732. if A_inv == True:
  733. return 'A'
  734. B_inv = ask(Q.invertible(B))
  735. if B_inv == True:
  736. return 'B'
  737. C_inv = ask(Q.invertible(C))
  738. if C_inv == True:
  739. return 'C'
  740. D_inv = ask(Q.invertible(D))
  741. if D_inv == True:
  742. return 'D'
  743. # Otherwise try to find a matrix that isn't known to be non-invertible
  744. if A_inv != False:
  745. return 'A'
  746. if B_inv != False:
  747. return 'B'
  748. if C_inv != False:
  749. return 'C'
  750. if D_inv != False:
  751. return 'D'
  752. return None
  753. def deblock(B):
  754. """ Flatten a BlockMatrix of BlockMatrices """
  755. if not isinstance(B, BlockMatrix) or not B.blocks.has(BlockMatrix):
  756. return B
  757. wrap = lambda x: x if isinstance(x, BlockMatrix) else BlockMatrix([[x]])
  758. bb = B.blocks.applyfunc(wrap) # everything is a block
  759. try:
  760. MM = Matrix(0, sum(bb[0, i].blocks.shape[1] for i in range(bb.shape[1])), [])
  761. for row in range(0, bb.shape[0]):
  762. M = Matrix(bb[row, 0].blocks)
  763. for col in range(1, bb.shape[1]):
  764. M = M.row_join(bb[row, col].blocks)
  765. MM = MM.col_join(M)
  766. return BlockMatrix(MM)
  767. except ShapeError:
  768. return B
  769. def reblock_2x2(expr):
  770. """
  771. Reblock a BlockMatrix so that it has 2x2 blocks of block matrices. If
  772. possible in such a way that the matrix continues to be invertible using the
  773. classical 2x2 block inversion formulas.
  774. """
  775. if not isinstance(expr, BlockMatrix) or not all(d > 2 for d in expr.blockshape):
  776. return expr
  777. BM = BlockMatrix # for brevity's sake
  778. rowblocks, colblocks = expr.blockshape
  779. blocks = expr.blocks
  780. for i in range(1, rowblocks):
  781. for j in range(1, colblocks):
  782. # try to split rows at i and cols at j
  783. A = bc_unpack(BM(blocks[:i, :j]))
  784. B = bc_unpack(BM(blocks[:i, j:]))
  785. C = bc_unpack(BM(blocks[i:, :j]))
  786. D = bc_unpack(BM(blocks[i:, j:]))
  787. formula = _choose_2x2_inversion_formula(A, B, C, D)
  788. if formula is not None:
  789. return BlockMatrix([[A, B], [C, D]])
  790. # else: nothing worked, just split upper left corner
  791. return BM([[blocks[0, 0], BM(blocks[0, 1:])],
  792. [BM(blocks[1:, 0]), BM(blocks[1:, 1:])]])
  793. def bounds(sizes):
  794. """ Convert sequence of numbers into pairs of low-high pairs
  795. >>> from sympy.matrices.expressions.blockmatrix import bounds
  796. >>> bounds((1, 10, 50))
  797. [(0, 1), (1, 11), (11, 61)]
  798. """
  799. low = 0
  800. rv = []
  801. for size in sizes:
  802. rv.append((low, low + size))
  803. low += size
  804. return rv
  805. def blockcut(expr, rowsizes, colsizes):
  806. """ Cut a matrix expression into Blocks
  807. >>> from sympy import ImmutableMatrix, blockcut
  808. >>> M = ImmutableMatrix(4, 4, range(16))
  809. >>> B = blockcut(M, (1, 3), (1, 3))
  810. >>> type(B).__name__
  811. 'BlockMatrix'
  812. >>> ImmutableMatrix(B.blocks[0, 1])
  813. Matrix([[1, 2, 3]])
  814. """
  815. rowbounds = bounds(rowsizes)
  816. colbounds = bounds(colsizes)
  817. return BlockMatrix([[MatrixSlice(expr, rowbound, colbound)
  818. for colbound in colbounds]
  819. for rowbound in rowbounds])