test_blockmatrix.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. from sympy.matrices.expressions.trace import Trace
  2. from sympy.testing.pytest import raises, slow
  3. from sympy.matrices.expressions.blockmatrix import (
  4. block_collapse, bc_matmul, bc_block_plus_ident, BlockDiagMatrix,
  5. BlockMatrix, bc_dist, bc_matadd, bc_transpose, bc_inverse,
  6. blockcut, reblock_2x2, deblock)
  7. from sympy.matrices.expressions import (
  8. MatrixSymbol, Identity, trace, det, ZeroMatrix, OneMatrix)
  9. from sympy.matrices.expressions.inverse import Inverse
  10. from sympy.matrices.expressions.matpow import MatPow
  11. from sympy.matrices.expressions.transpose import Transpose
  12. from sympy.matrices.exceptions import NonInvertibleMatrixError
  13. from sympy.matrices import (
  14. Matrix, ImmutableMatrix, ImmutableSparseMatrix, zeros)
  15. from sympy.core import Tuple, Expr, S, Function
  16. from sympy.core.symbol import Symbol, symbols
  17. from sympy.functions import transpose, im, re
  18. i, j, k, l, m, n, p = symbols('i:n, p', integer=True)
  19. A = MatrixSymbol('A', n, n)
  20. B = MatrixSymbol('B', n, n)
  21. C = MatrixSymbol('C', n, n)
  22. D = MatrixSymbol('D', n, n)
  23. G = MatrixSymbol('G', n, n)
  24. H = MatrixSymbol('H', n, n)
  25. b1 = BlockMatrix([[G, H]])
  26. b2 = BlockMatrix([[G], [H]])
  27. def test_bc_matmul():
  28. assert bc_matmul(H*b1*b2*G) == BlockMatrix([[(H*G*G + H*H*H)*G]])
  29. def test_bc_matadd():
  30. assert bc_matadd(BlockMatrix([[G, H]]) + BlockMatrix([[H, H]])) == \
  31. BlockMatrix([[G+H, H+H]])
  32. def test_bc_transpose():
  33. assert bc_transpose(Transpose(BlockMatrix([[A, B], [C, D]]))) == \
  34. BlockMatrix([[A.T, C.T], [B.T, D.T]])
  35. def test_bc_dist_diag():
  36. A = MatrixSymbol('A', n, n)
  37. B = MatrixSymbol('B', m, m)
  38. C = MatrixSymbol('C', l, l)
  39. X = BlockDiagMatrix(A, B, C)
  40. assert bc_dist(X+X).equals(BlockDiagMatrix(2*A, 2*B, 2*C))
  41. def test_block_plus_ident():
  42. A = MatrixSymbol('A', n, n)
  43. B = MatrixSymbol('B', n, m)
  44. C = MatrixSymbol('C', m, n)
  45. D = MatrixSymbol('D', m, m)
  46. X = BlockMatrix([[A, B], [C, D]])
  47. Z = MatrixSymbol('Z', n + m, n + m)
  48. assert bc_block_plus_ident(X + Identity(m + n) + Z) == \
  49. BlockDiagMatrix(Identity(n), Identity(m)) + X + Z
  50. def test_BlockMatrix():
  51. A = MatrixSymbol('A', n, m)
  52. B = MatrixSymbol('B', n, k)
  53. C = MatrixSymbol('C', l, m)
  54. D = MatrixSymbol('D', l, k)
  55. M = MatrixSymbol('M', m + k, p)
  56. N = MatrixSymbol('N', l + n, k + m)
  57. X = BlockMatrix(Matrix([[A, B], [C, D]]))
  58. assert X.__class__(*X.args) == X
  59. # block_collapse does nothing on normal inputs
  60. E = MatrixSymbol('E', n, m)
  61. assert block_collapse(A + 2*E) == A + 2*E
  62. F = MatrixSymbol('F', m, m)
  63. assert block_collapse(E.T*A*F) == E.T*A*F
  64. assert X.shape == (l + n, k + m)
  65. assert X.blockshape == (2, 2)
  66. assert transpose(X) == BlockMatrix(Matrix([[A.T, C.T], [B.T, D.T]]))
  67. assert transpose(X).shape == X.shape[::-1]
  68. # Test that BlockMatrices and MatrixSymbols can still mix
  69. assert (X*M).is_MatMul
  70. assert X._blockmul(M).is_MatMul
  71. assert (X*M).shape == (n + l, p)
  72. assert (X + N).is_MatAdd
  73. assert X._blockadd(N).is_MatAdd
  74. assert (X + N).shape == X.shape
  75. E = MatrixSymbol('E', m, 1)
  76. F = MatrixSymbol('F', k, 1)
  77. Y = BlockMatrix(Matrix([[E], [F]]))
  78. assert (X*Y).shape == (l + n, 1)
  79. assert block_collapse(X*Y).blocks[0, 0] == A*E + B*F
  80. assert block_collapse(X*Y).blocks[1, 0] == C*E + D*F
  81. # block_collapse passes down into container objects, transposes, and inverse
  82. assert block_collapse(transpose(X*Y)) == transpose(block_collapse(X*Y))
  83. assert block_collapse(Tuple(X*Y, 2*X)) == (
  84. block_collapse(X*Y), block_collapse(2*X))
  85. # Make sure that MatrixSymbols will enter 1x1 BlockMatrix if it simplifies
  86. Ab = BlockMatrix([[A]])
  87. Z = MatrixSymbol('Z', *A.shape)
  88. assert block_collapse(Ab + Z) == A + Z
  89. def test_block_collapse_explicit_matrices():
  90. A = Matrix([[1, 2], [3, 4]])
  91. assert block_collapse(BlockMatrix([[A]])) == A
  92. A = ImmutableSparseMatrix([[1, 2], [3, 4]])
  93. assert block_collapse(BlockMatrix([[A]])) == A
  94. def test_issue_17624():
  95. a = MatrixSymbol("a", 2, 2)
  96. z = ZeroMatrix(2, 2)
  97. b = BlockMatrix([[a, z], [z, z]])
  98. assert block_collapse(b * b) == BlockMatrix([[a**2, z], [z, z]])
  99. assert block_collapse(b * b * b) == BlockMatrix([[a**3, z], [z, z]])
  100. def test_issue_18618():
  101. A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  102. assert A == Matrix(BlockDiagMatrix(A))
  103. def test_BlockMatrix_trace():
  104. A, B, C, D = [MatrixSymbol(s, 3, 3) for s in 'ABCD']
  105. X = BlockMatrix([[A, B], [C, D]])
  106. assert trace(X) == trace(A) + trace(D)
  107. assert trace(BlockMatrix([ZeroMatrix(n, n)])) == 0
  108. def test_BlockMatrix_Determinant():
  109. A, B, C, D = [MatrixSymbol(s, 3, 3) for s in 'ABCD']
  110. X = BlockMatrix([[A, B], [C, D]])
  111. from sympy.assumptions.ask import Q
  112. from sympy.assumptions.assume import assuming
  113. with assuming(Q.invertible(A)):
  114. assert det(X) == det(A) * det(X.schur('A'))
  115. assert isinstance(det(X), Expr)
  116. assert det(BlockMatrix([A])) == det(A)
  117. assert det(BlockMatrix([ZeroMatrix(n, n)])) == 0
  118. def test_squareBlockMatrix():
  119. A = MatrixSymbol('A', n, n)
  120. B = MatrixSymbol('B', n, m)
  121. C = MatrixSymbol('C', m, n)
  122. D = MatrixSymbol('D', m, m)
  123. X = BlockMatrix([[A, B], [C, D]])
  124. Y = BlockMatrix([[A]])
  125. assert X.is_square
  126. Q = X + Identity(m + n)
  127. assert (block_collapse(Q) ==
  128. BlockMatrix([[A + Identity(n), B], [C, D + Identity(m)]]))
  129. assert (X + MatrixSymbol('Q', n + m, n + m)).is_MatAdd
  130. assert (X * MatrixSymbol('Q', n + m, n + m)).is_MatMul
  131. assert block_collapse(Y.I) == A.I
  132. assert isinstance(X.inverse(), Inverse)
  133. assert not X.is_Identity
  134. Z = BlockMatrix([[Identity(n), B], [C, D]])
  135. assert not Z.is_Identity
  136. def test_BlockMatrix_2x2_inverse_symbolic():
  137. A = MatrixSymbol('A', n, m)
  138. B = MatrixSymbol('B', n, k - m)
  139. C = MatrixSymbol('C', k - n, m)
  140. D = MatrixSymbol('D', k - n, k - m)
  141. X = BlockMatrix([[A, B], [C, D]])
  142. assert X.is_square and X.shape == (k, k)
  143. assert isinstance(block_collapse(X.I), Inverse) # Can't invert when none of the blocks is square
  144. # test code path where only A is invertible
  145. A = MatrixSymbol('A', n, n)
  146. B = MatrixSymbol('B', n, m)
  147. C = MatrixSymbol('C', m, n)
  148. D = ZeroMatrix(m, m)
  149. X = BlockMatrix([[A, B], [C, D]])
  150. assert block_collapse(X.inverse()) == BlockMatrix([
  151. [A.I + A.I * B * X.schur('A').I * C * A.I, -A.I * B * X.schur('A').I],
  152. [-X.schur('A').I * C * A.I, X.schur('A').I],
  153. ])
  154. # test code path where only B is invertible
  155. A = MatrixSymbol('A', n, m)
  156. B = MatrixSymbol('B', n, n)
  157. C = ZeroMatrix(m, m)
  158. D = MatrixSymbol('D', m, n)
  159. X = BlockMatrix([[A, B], [C, D]])
  160. assert block_collapse(X.inverse()) == BlockMatrix([
  161. [-X.schur('B').I * D * B.I, X.schur('B').I],
  162. [B.I + B.I * A * X.schur('B').I * D * B.I, -B.I * A * X.schur('B').I],
  163. ])
  164. # test code path where only C is invertible
  165. A = MatrixSymbol('A', n, m)
  166. B = ZeroMatrix(n, n)
  167. C = MatrixSymbol('C', m, m)
  168. D = MatrixSymbol('D', m, n)
  169. X = BlockMatrix([[A, B], [C, D]])
  170. assert block_collapse(X.inverse()) == BlockMatrix([
  171. [-C.I * D * X.schur('C').I, C.I + C.I * D * X.schur('C').I * A * C.I],
  172. [X.schur('C').I, -X.schur('C').I * A * C.I],
  173. ])
  174. # test code path where only D is invertible
  175. A = ZeroMatrix(n, n)
  176. B = MatrixSymbol('B', n, m)
  177. C = MatrixSymbol('C', m, n)
  178. D = MatrixSymbol('D', m, m)
  179. X = BlockMatrix([[A, B], [C, D]])
  180. assert block_collapse(X.inverse()) == BlockMatrix([
  181. [X.schur('D').I, -X.schur('D').I * B * D.I],
  182. [-D.I * C * X.schur('D').I, D.I + D.I * C * X.schur('D').I * B * D.I],
  183. ])
  184. def test_BlockMatrix_2x2_inverse_numeric():
  185. """Test 2x2 block matrix inversion numerically for all 4 formulas"""
  186. M = Matrix([[1, 2], [3, 4]])
  187. # rank deficient matrices that have full rank when two of them combined
  188. D1 = Matrix([[1, 2], [2, 4]])
  189. D2 = Matrix([[1, 3], [3, 9]])
  190. D3 = Matrix([[1, 4], [4, 16]])
  191. assert D1.rank() == D2.rank() == D3.rank() == 1
  192. assert (D1 + D2).rank() == (D2 + D3).rank() == (D3 + D1).rank() == 2
  193. # Only A is invertible
  194. K = BlockMatrix([[M, D1], [D2, D3]])
  195. assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv()
  196. # Only B is invertible
  197. K = BlockMatrix([[D1, M], [D2, D3]])
  198. assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv()
  199. # Only C is invertible
  200. K = BlockMatrix([[D1, D2], [M, D3]])
  201. assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv()
  202. # Only D is invertible
  203. K = BlockMatrix([[D1, D2], [D3, M]])
  204. assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv()
  205. @slow
  206. def test_BlockMatrix_3x3_symbolic():
  207. # Only test one of these, instead of all permutations, because it's slow
  208. rowblocksizes = (n, m, k)
  209. colblocksizes = (m, k, n)
  210. K = BlockMatrix([
  211. [MatrixSymbol('M%s%s' % (rows, cols), rows, cols) for cols in colblocksizes]
  212. for rows in rowblocksizes
  213. ])
  214. collapse = block_collapse(K.I)
  215. assert isinstance(collapse, BlockMatrix)
  216. def test_BlockDiagMatrix():
  217. A = MatrixSymbol('A', n, n)
  218. B = MatrixSymbol('B', m, m)
  219. C = MatrixSymbol('C', l, l)
  220. M = MatrixSymbol('M', n + m + l, n + m + l)
  221. X = BlockDiagMatrix(A, B, C)
  222. Y = BlockDiagMatrix(A, 2*B, 3*C)
  223. assert X.blocks[1, 1] == B
  224. assert X.shape == (n + m + l, n + m + l)
  225. assert all(X.blocks[i, j].is_ZeroMatrix if i != j else X.blocks[i, j] in [A, B, C]
  226. for i in range(3) for j in range(3))
  227. assert X.__class__(*X.args) == X
  228. assert X.get_diag_blocks() == (A, B, C)
  229. assert isinstance(block_collapse(X.I * X), Identity)
  230. assert bc_matmul(X*X) == BlockDiagMatrix(A*A, B*B, C*C)
  231. assert block_collapse(X*X) == BlockDiagMatrix(A*A, B*B, C*C)
  232. #XXX: should be == ??
  233. assert block_collapse(X + X).equals(BlockDiagMatrix(2*A, 2*B, 2*C))
  234. assert block_collapse(X*Y) == BlockDiagMatrix(A*A, 2*B*B, 3*C*C)
  235. assert block_collapse(X + Y) == BlockDiagMatrix(2*A, 3*B, 4*C)
  236. # Ensure that BlockDiagMatrices can still interact with normal MatrixExprs
  237. assert (X*(2*M)).is_MatMul
  238. assert (X + (2*M)).is_MatAdd
  239. assert (X._blockmul(M)).is_MatMul
  240. assert (X._blockadd(M)).is_MatAdd
  241. def test_BlockDiagMatrix_nonsquare():
  242. A = MatrixSymbol('A', n, m)
  243. B = MatrixSymbol('B', k, l)
  244. X = BlockDiagMatrix(A, B)
  245. assert X.shape == (n + k, m + l)
  246. assert X.shape == (n + k, m + l)
  247. assert X.rowblocksizes == [n, k]
  248. assert X.colblocksizes == [m, l]
  249. C = MatrixSymbol('C', n, m)
  250. D = MatrixSymbol('D', k, l)
  251. Y = BlockDiagMatrix(C, D)
  252. assert block_collapse(X + Y) == BlockDiagMatrix(A + C, B + D)
  253. assert block_collapse(X * Y.T) == BlockDiagMatrix(A * C.T, B * D.T)
  254. raises(NonInvertibleMatrixError, lambda: BlockDiagMatrix(A, C.T).inverse())
  255. def test_BlockDiagMatrix_determinant():
  256. A = MatrixSymbol('A', n, n)
  257. B = MatrixSymbol('B', m, m)
  258. assert det(BlockDiagMatrix()) == 1
  259. assert det(BlockDiagMatrix(A)) == det(A)
  260. assert det(BlockDiagMatrix(A, B)) == det(A) * det(B)
  261. # non-square blocks
  262. C = MatrixSymbol('C', m, n)
  263. D = MatrixSymbol('D', n, m)
  264. assert det(BlockDiagMatrix(C, D)) == 0
  265. def test_BlockDiagMatrix_trace():
  266. assert trace(BlockDiagMatrix()) == 0
  267. assert trace(BlockDiagMatrix(ZeroMatrix(n, n))) == 0
  268. A = MatrixSymbol('A', n, n)
  269. assert trace(BlockDiagMatrix(A)) == trace(A)
  270. B = MatrixSymbol('B', m, m)
  271. assert trace(BlockDiagMatrix(A, B)) == trace(A) + trace(B)
  272. # non-square blocks
  273. C = MatrixSymbol('C', m, n)
  274. D = MatrixSymbol('D', n, m)
  275. assert isinstance(trace(BlockDiagMatrix(C, D)), Trace)
  276. def test_BlockDiagMatrix_transpose():
  277. A = MatrixSymbol('A', n, m)
  278. B = MatrixSymbol('B', k, l)
  279. assert transpose(BlockDiagMatrix()) == BlockDiagMatrix()
  280. assert transpose(BlockDiagMatrix(A)) == BlockDiagMatrix(A.T)
  281. assert transpose(BlockDiagMatrix(A, B)) == BlockDiagMatrix(A.T, B.T)
  282. def test_issue_2460():
  283. bdm1 = BlockDiagMatrix(Matrix([i]), Matrix([j]))
  284. bdm2 = BlockDiagMatrix(Matrix([k]), Matrix([l]))
  285. assert block_collapse(bdm1 + bdm2) == BlockDiagMatrix(Matrix([i + k]), Matrix([j + l]))
  286. def test_blockcut():
  287. A = MatrixSymbol('A', n, m)
  288. B = blockcut(A, (n/2, n/2), (m/2, m/2))
  289. assert B == BlockMatrix([[A[:n/2, :m/2], A[:n/2, m/2:]],
  290. [A[n/2:, :m/2], A[n/2:, m/2:]]])
  291. M = ImmutableMatrix(4, 4, range(16))
  292. B = blockcut(M, (2, 2), (2, 2))
  293. assert M == ImmutableMatrix(B)
  294. B = blockcut(M, (1, 3), (2, 2))
  295. assert ImmutableMatrix(B.blocks[0, 1]) == ImmutableMatrix([[2, 3]])
  296. def test_reblock_2x2():
  297. B = BlockMatrix([[MatrixSymbol('A_%d%d'%(i,j), 2, 2)
  298. for j in range(3)]
  299. for i in range(3)])
  300. assert B.blocks.shape == (3, 3)
  301. BB = reblock_2x2(B)
  302. assert BB.blocks.shape == (2, 2)
  303. assert B.shape == BB.shape
  304. assert B.as_explicit() == BB.as_explicit()
  305. def test_deblock():
  306. B = BlockMatrix([[MatrixSymbol('A_%d%d'%(i,j), n, n)
  307. for j in range(4)]
  308. for i in range(4)])
  309. assert deblock(reblock_2x2(B)) == B
  310. def test_block_collapse_type():
  311. bm1 = BlockDiagMatrix(ImmutableMatrix([1]), ImmutableMatrix([2]))
  312. bm2 = BlockDiagMatrix(ImmutableMatrix([3]), ImmutableMatrix([4]))
  313. assert bm1.T.__class__ == BlockDiagMatrix
  314. assert block_collapse(bm1 - bm2).__class__ == BlockDiagMatrix
  315. assert block_collapse(Inverse(bm1)).__class__ == BlockDiagMatrix
  316. assert block_collapse(Transpose(bm1)).__class__ == BlockDiagMatrix
  317. assert bc_transpose(Transpose(bm1)).__class__ == BlockDiagMatrix
  318. assert bc_inverse(Inverse(bm1)).__class__ == BlockDiagMatrix
  319. def test_invalid_block_matrix():
  320. raises(ValueError, lambda: BlockMatrix([
  321. [Identity(2), Identity(5)],
  322. ]))
  323. raises(ValueError, lambda: BlockMatrix([
  324. [Identity(n), Identity(m)],
  325. ]))
  326. raises(ValueError, lambda: BlockMatrix([
  327. [ZeroMatrix(n, n), ZeroMatrix(n, n)],
  328. [ZeroMatrix(n, n - 1), ZeroMatrix(n, n + 1)],
  329. ]))
  330. raises(ValueError, lambda: BlockMatrix([
  331. [ZeroMatrix(n - 1, n), ZeroMatrix(n, n)],
  332. [ZeroMatrix(n + 1, n), ZeroMatrix(n, n)],
  333. ]))
  334. def test_block_lu_decomposition():
  335. A = MatrixSymbol('A', n, n)
  336. B = MatrixSymbol('B', n, m)
  337. C = MatrixSymbol('C', m, n)
  338. D = MatrixSymbol('D', m, m)
  339. X = BlockMatrix([[A, B], [C, D]])
  340. #LDU decomposition
  341. L, D, U = X.LDUdecomposition()
  342. assert block_collapse(L*D*U) == X
  343. #UDL decomposition
  344. U, D, L = X.UDLdecomposition()
  345. assert block_collapse(U*D*L) == X
  346. #LU decomposition
  347. L, U = X.LUdecomposition()
  348. assert block_collapse(L*U) == X
  349. def test_issue_21866():
  350. n = 10
  351. I = Identity(n)
  352. O = ZeroMatrix(n, n)
  353. A = BlockMatrix([[ I, O, O, O ],
  354. [ O, I, O, O ],
  355. [ O, O, I, O ],
  356. [ I, O, O, I ]])
  357. Ainv = block_collapse(A.inv())
  358. AinvT = BlockMatrix([[ I, O, O, O ],
  359. [ O, I, O, O ],
  360. [ O, O, I, O ],
  361. [ -I, O, O, I ]])
  362. assert Ainv == AinvT
  363. def test_adjoint_and_special_matrices():
  364. A = Identity(3)
  365. B = OneMatrix(3, 2)
  366. C = ZeroMatrix(2, 3)
  367. D = Identity(2)
  368. X = BlockMatrix([[A, B], [C, D]])
  369. X2 = BlockMatrix([[A, S.ImaginaryUnit*B], [C, D]])
  370. assert X.adjoint() == BlockMatrix([[A, ZeroMatrix(3, 2)], [OneMatrix(2, 3), D]])
  371. assert re(X) == X
  372. assert X2.adjoint() == BlockMatrix([[A, ZeroMatrix(3, 2)], [-S.ImaginaryUnit*OneMatrix(2, 3), D]])
  373. assert im(X2) == BlockMatrix([[ZeroMatrix(3, 3), OneMatrix(3, 2)], [ZeroMatrix(2, 3), ZeroMatrix(2, 2)]])
  374. def test_block_matrix_derivative():
  375. x = symbols('x')
  376. A = Matrix(3, 3, [Function(f'a{i}')(x) for i in range(9)])
  377. bc = BlockMatrix([[A[:2, :2], A[:2, 2]], [A[2, :2], A[2:, 2]]])
  378. assert Matrix(bc.diff(x)) - A.diff(x) == zeros(3, 3)
  379. def test_transpose_inverse_commute():
  380. n = Symbol('n')
  381. I = Identity(n)
  382. Z = ZeroMatrix(n, n)
  383. A = BlockMatrix([[I, Z], [Z, I]])
  384. assert block_collapse(A.transpose().inverse()) == A
  385. assert block_collapse(A.inverse().transpose()) == A
  386. assert block_collapse(MatPow(A.transpose(), -2)) == MatPow(A, -2)
  387. assert block_collapse(MatPow(A, -2).transpose()) == MatPow(A, -2)