matmul.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. from sympy.assumptions.ask import ask, Q
  2. from sympy.assumptions.refine import handlers_dict
  3. from sympy.core import Basic, sympify, S
  4. from sympy.core.mul import mul, Mul
  5. from sympy.core.numbers import Number, Integer
  6. from sympy.core.symbol import Dummy
  7. from sympy.functions import adjoint
  8. from sympy.strategies import (rm_id, unpack, typed, flatten, exhaust,
  9. do_one, new)
  10. from sympy.matrices.exceptions import NonInvertibleMatrixError
  11. from sympy.matrices.matrixbase import MatrixBase
  12. from sympy.utilities.exceptions import sympy_deprecation_warning
  13. from sympy.matrices.expressions._shape import validate_matmul_integer as validate
  14. from .inverse import Inverse
  15. from .matexpr import MatrixExpr
  16. from .matpow import MatPow
  17. from .transpose import transpose
  18. from .permutation import PermutationMatrix
  19. from .special import ZeroMatrix, Identity, GenericIdentity, OneMatrix
# XXX: MatMul should perhaps not subclass directly from Mul
class MatMul(MatrixExpr, Mul):
    """
    A product of matrix expressions

    Examples
    ========

    >>> from sympy import MatMul, MatrixSymbol
    >>> A = MatrixSymbol('A', 5, 4)
    >>> B = MatrixSymbol('B', 4, 3)
    >>> C = MatrixSymbol('C', 3, 6)
    >>> MatMul(A, B, C)
    A*B*C
    """
    is_MatMul = True

    # Multiplicative identity: stripped from the arguments in __new__ and
    # returned when MatMul() is called with no arguments at all.
    identity = GenericIdentity()

    def __new__(cls, *args, evaluate=False, check=None, _sympify=True):
        """Build a (by default unevaluated) product of ``args``.

        Parameters
        ==========

        args : scalar and matrix factors, multiplied left to right.
        evaluate : when true, the product is canonicalized via ``_evaluate``.
        check : deprecated since 1.11 (a deprecation warning is emitted when
            it is passed). Passing ``False`` skips ``validate(*matrices)``.
        _sympify : when true, every argument is passed through ``sympify``.

        Returns ``cls.identity`` when called with no arguments, and the bare
        scalar coefficient when no matrix factor is present.
        """
        if not args:
            return cls.identity

        # This must be removed aggressively in the constructor to avoid
        # TypeErrors from GenericIdentity().shape
        args = list(filter(lambda i: cls.identity != i, args))
        if _sympify:
            args = list(map(sympify, args))
        obj = Basic.__new__(cls, *args)
        factor, matrices = obj.as_coeff_matrices()

        if check is not None:
            sympy_deprecation_warning(
                "Passing check to MatMul is deprecated and the check argument will be removed in a future version.",
                deprecated_since_version="1.11",
                active_deprecations_target='remove-check-argument-from-matrix-operations')

        if check is not False:
            # Shape validation of the matrix factors (validate_matmul_integer).
            validate(*matrices)

        if not matrices:
            # Should it be
            #
            # return Basic.__new__(cls, factor, GenericIdentity()) ?
            return factor

        if evaluate:
            return cls._evaluate(obj)

        return obj

    @classmethod
    def _evaluate(cls, expr):
        # Delegate to the module-level ``canonicalize`` rule set.
        return canonicalize(expr)

    @property
    def shape(self):
        # Rows of the first matrix factor by columns of the last one; scalar
        # factors do not contribute to the shape.
        matrices = [arg for arg in self.args if arg.is_Matrix]
        return (matrices[0].rows, matrices[-1].cols)

    def _entry(self, i, j, expand=True, **kwargs):
        """Return entry ``(i, j)`` of the product.

        For a single matrix factor ``X`` this is ``coeff * X[i, j]``.
        Otherwise it is ``coeff * Sum(M1[i, k1]*M2[k1, k2]*...*Mn[k_{n-1}, j])``
        with each dummy ``k`` running from 0 to the inner dimension minus 1.
        The sum is evaluated with ``doit()`` only when at least one upper
        bound is an integer and either ``expand`` is true or some factor
        entry contains an ``ImmutableMatrix``.

        A ``dummy_generator`` iterator may be passed in ``kwargs`` to supply
        the summation dummies; it is forwarded to the factors' ``_entry``.
        """
        # Avoid cyclic imports
        from sympy.concrete.summations import Sum
        from sympy.matrices.immutable import ImmutableMatrix

        coeff, matrices = self.as_coeff_matrices()

        if len(matrices) == 1:  # situation like 2*X, matmul is just X
            return coeff * matrices[0][i, j]

        # indices[k] is the row index into matrices[k] (and the column index
        # into matrices[k-1]); the two outer slots are the requested i and j.
        indices = [None]*(len(matrices) + 1)
        # Inclusive upper bound of each inner summation index.
        ind_ranges = [None]*(len(matrices) - 1)
        indices[0] = i
        indices[-1] = j

        # Default source of fresh summation dummies i_1, i_2, ...
        def f():
            counter = 1
            while True:
                yield Dummy("i_%i" % counter)
                counter += 1

        dummy_generator = kwargs.get("dummy_generator", f())

        # NOTE: the loop variables below shadow the parameter ``i``; this is
        # harmless only because indices[0] was already assigned above.
        for i in range(1, len(matrices)):
            indices[i] = next(dummy_generator)

        for i, arg in enumerate(matrices[:-1]):
            ind_ranges[i] = arg.shape[1] - 1
        matrices = [arg._entry(indices[i], indices[i+1], dummy_generator=dummy_generator) for i, arg in enumerate(matrices)]
        expr_in_sum = Mul.fromiter(matrices)
        if any(v.has(ImmutableMatrix) for v in matrices):
            expand = True
        result = coeff*Sum(
                expr_in_sum,
                *zip(indices[1:-1], [0]*len(ind_ranges), ind_ranges)
            )

        # Don't waste time in result.doit() if the sum bounds are symbolic
        if not any(isinstance(v, (Integer, int)) for v in ind_ranges):
            expand = False
        return result.doit() if expand else result

    def as_coeff_matrices(self):
        """Split the arguments into ``(coeff, matrices)``.

        ``coeff`` is the ``Mul`` of every non-matrix argument and
        ``matrices`` is a new list of the matrix arguments, in order.

        Raises NotImplementedError if the scalar coefficient is
        non-commutative.
        """
        scalars = [x for x in self.args if not x.is_Matrix]
        matrices = [x for x in self.args if x.is_Matrix]
        coeff = Mul(*scalars)
        if coeff.is_commutative is False:
            raise NotImplementedError("noncommutative scalars in MatMul are not supported.")

        return coeff, matrices

    def as_coeff_mmul(self):
        # Like as_coeff_matrices, but the matrix factors are wrapped in an
        # unevaluated MatMul.
        coeff, matrices = self.as_coeff_matrices()
        return coeff, MatMul(*matrices)

    def expand(self, **kwargs):
        # Run the inherited expand, then canonicalize the result.
        expanded = super(MatMul, self).expand(**kwargs)
        return self._evaluate(expanded)

    def _eval_transpose(self):
        """Transposition of matrix multiplication.

        Notes
        =====

        The following rules are applied.

        Transposition for matrix multiplied with another matrix:
        `\\left(A B\\right)^{T} = B^{T} A^{T}`

        Transposition for matrix multiplied with scalar:
        `\\left(c A\\right)^{T} = c A^{T}`

        References
        ==========

        .. [1] https://en.wikipedia.org/wiki/Transpose
        """
        coeff, matrices = self.as_coeff_matrices()
        return MatMul(
            coeff, *[transpose(arg) for arg in matrices[::-1]]).doit()

    def _eval_adjoint(self):
        # (A B)^H = B^H A^H: reverse the factors and take the adjoint of each;
        # scalar arguments are also passed through ``adjoint``.
        return MatMul(*[adjoint(arg) for arg in self.args[::-1]]).doit()

    def _eval_trace(self):
        # Pull a non-unit scalar factor out of the trace. When the factor is
        # 1 this falls through and returns None -- presumably read by the
        # caller as "no simplification available"; confirm in trace.py.
        factor, mmul = self.as_coeff_mmul()
        if factor != 1:
            from .trace import trace
            return factor * trace(mmul.doit())

    def _eval_determinant(self):
        # det(c * M1*...*Mn) = c**rows * product of the determinants of the
        # square sub-products found by only_squares.
        from sympy.matrices.expressions.determinant import Determinant
        factor, matrices = self.as_coeff_matrices()
        square_matrices = only_squares(*matrices)
        return factor**self.rows * Mul(*list(map(Determinant, square_matrices)))

    def _eval_inverse(self):
        # (A B)^{-1} = B^{-1} A^{-1} when every matrix factor is square:
        # reverse the arguments, invert the matrices and take the reciprocal
        # of the scalars. Otherwise return an unevaluated Inverse.
        if all(arg.is_square for arg in self.args if isinstance(arg, MatrixExpr)):
            return MatMul(*(
                arg.inverse() if isinstance(arg, MatrixExpr) else arg**-1
                    for arg in self.args[::-1]
                )
            ).doit()
        return Inverse(self)

    def doit(self, **hints):
        """Rebuild the product and canonicalize it.

        With the ``deep`` hint (default True), each argument is
        ``doit``-evaluated first.
        """
        deep = hints.get('deep', True)
        if deep:
            args = tuple(arg.doit(**hints) for arg in self.args)
        else:
            args = self.args

        # treat scalar*MatrixSymbol or scalar*MatPow separately
        # NOTE(review): this comment looks stale -- the code below does not
        # special-case those forms, it only calls canonicalize().
        expr = canonicalize(MatMul(*args))
        return expr

    # Needed for partial compatibility with Mul
    def args_cnc(self, cset=False, warn=True, **kwargs):
        # Split args into [commutative, non-commutative] lists. With ``cset``
        # the commutative part becomes a set, and (when ``warn`` is true) a
        # ValueError is raised if that loses repeated arguments.
        coeff_c = [x for x in self.args if x.is_commutative]
        coeff_nc = [x for x in self.args if not x.is_commutative]

        if cset:
            clen = len(coeff_c)
            coeff_c = set(coeff_c)
            if clen and warn and len(coeff_c) != clen:
                raise ValueError('repeated commutative arguments: %s' %
                                 [ci for ci in coeff_c if list(self.args).count(ci) > 1])
        return [coeff_c, coeff_nc]

    def _eval_derivative_matrix_lines(self, x):
        # Product rule: for each factor that contains ``x``, take that
        # factor's derivative lines and attach the transposed, reversed
        # product of the factors to its left (``append_first``) and the
        # product of the factors to its right (``append_second``). An
        # Identity of the matching size stands in for an empty side.
        # NOTE(review): the line objects come from the factors'
        # _eval_derivative_matrix_lines -- their type is not shown here.
        from .transpose import Transpose
        with_x_ind = [i for i, arg in enumerate(self.args) if arg.has(x)]
        lines = []
        for ind in with_x_ind:
            left_args = self.args[:ind]
            right_args = self.args[ind+1:]

            if right_args:
                right_mat = MatMul.fromiter(right_args)
            else:
                right_mat = Identity(self.shape[1])
            if left_args:
                left_rev = MatMul.fromiter([Transpose(i).doit() if i.is_Matrix else i for i in reversed(left_args)])
            else:
                left_rev = Identity(self.shape[0])

            d = self.args[ind]._eval_derivative_matrix_lines(x)
            for i in d:
                i.append_first(left_rev)
                i.append_second(right_mat)
                lines.append(i)

        return lines
# Register MatMul with SymPy's ``mul`` dispatcher for the (Mul, MatMul) class
# pair -- presumably so that combining a scalar Mul with a MatMul yields a
# MatMul; see sympy.core.mul / AssocOpDispatcher to confirm.
mul.register_handlerclass((Mul, MatMul), MatMul)

# Rules
  192. def newmul(*args):
  193. if args[0] == 1:
  194. args = args[1:]
  195. return new(MatMul, *args)
  196. def any_zeros(mul):
  197. if any(arg.is_zero or (arg.is_Matrix and arg.is_ZeroMatrix)
  198. for arg in mul.args):
  199. matrices = [arg for arg in mul.args if arg.is_Matrix]
  200. return ZeroMatrix(matrices[0].rows, matrices[-1].cols)
  201. return mul
  202. def merge_explicit(matmul):
  203. """ Merge explicit MatrixBase arguments
  204. >>> from sympy import MatrixSymbol, Matrix, MatMul, pprint
  205. >>> from sympy.matrices.expressions.matmul import merge_explicit
  206. >>> A = MatrixSymbol('A', 2, 2)
  207. >>> B = Matrix([[1, 1], [1, 1]])
  208. >>> C = Matrix([[1, 2], [3, 4]])
  209. >>> X = MatMul(A, B, C)
  210. >>> pprint(X)
  211. [1 1] [1 2]
  212. A*[ ]*[ ]
  213. [1 1] [3 4]
  214. >>> pprint(merge_explicit(X))
  215. [4 6]
  216. A*[ ]
  217. [4 6]
  218. >>> X = MatMul(B, A, C)
  219. >>> pprint(X)
  220. [1 1] [1 2]
  221. [ ]*A*[ ]
  222. [1 1] [3 4]
  223. >>> pprint(merge_explicit(X))
  224. [1 1] [1 2]
  225. [ ]*A*[ ]
  226. [1 1] [3 4]
  227. """
  228. if not any(isinstance(arg, MatrixBase) for arg in matmul.args):
  229. return matmul
  230. newargs = []
  231. last = matmul.args[0]
  232. for arg in matmul.args[1:]:
  233. if isinstance(arg, (MatrixBase, Number)) and isinstance(last, (MatrixBase, Number)):
  234. last = last * arg
  235. else:
  236. newargs.append(last)
  237. last = arg
  238. newargs.append(last)
  239. return MatMul(*newargs)
  240. def remove_ids(mul):
  241. """ Remove Identities from a MatMul
  242. This is a modified version of sympy.strategies.rm_id.
  243. This is necessary because MatMul may contain both MatrixExprs and Exprs
  244. as args.
  245. See Also
  246. ========
  247. sympy.strategies.rm_id
  248. """
  249. # Separate Exprs from MatrixExprs in args
  250. factor, mmul = mul.as_coeff_mmul()
  251. # Apply standard rm_id for MatMuls
  252. result = rm_id(lambda x: x.is_Identity is True)(mmul)
  253. if result != mmul:
  254. return newmul(factor, *result.args) # Recombine and return
  255. else:
  256. return mul
  257. def factor_in_front(mul):
  258. factor, matrices = mul.as_coeff_matrices()
  259. if factor != 1:
  260. return newmul(factor, *matrices)
  261. return mul
def combine_powers(mul):
    r"""Combine consecutive powers with the same base into one, e.g.
    $$A \times A^2 \Rightarrow A^3$$

    This also cancels out the possible matrix inverses using the
    knowledgebase of :class:`~.Inverse`, e.g.,
    $$ Y \times X \times X^{-1} \Rightarrow Y $$
    """
    # ``args`` is a fresh list built by as_coeff_matrices, so the in-place
    # writes to it below stay local to this call.
    factor, args = mul.as_coeff_matrices()
    new_args = [args[0]]

    for i in range(1, len(args)):
        A = new_args[-1]
        B = args[i]

        # M1*...*Mk followed by Inverse(M1*...*Mk): replace the already
        # emitted run M1..Mk (and B itself) with a single identity.
        if isinstance(B, Inverse) and isinstance(B.arg, MatMul):
            Bargs = B.arg.args
            l = len(Bargs)
            if list(Bargs) == new_args[-l:]:
                new_args = new_args[:-l] + [Identity(B.shape[0])]
                continue

        # Inverse(M1*...*Mk) followed by M1, ..., Mk: turn A into an identity
        # and overwrite the upcoming factors in ``args`` with identities too.
        # Later iterations then see identity*identity and fold them through
        # the same-base MatPow branch below instead of re-emitting M1..Mk.
        if isinstance(A, Inverse) and isinstance(A.arg, MatMul):
            Aargs = A.arg.args
            l = len(Aargs)
            if list(Aargs) == args[i:i+l]:
                identity = Identity(A.shape[0])
                new_args[-1] = identity
                for j in range(i, i+l):
                    args[j] = identity
                continue

        # Powers are only combined between factors not known to be
        # non-square.
        if A.is_square == False or B.is_square == False:
            new_args.append(B)
            continue

        # Treat a non-power factor M as M**1.
        if isinstance(A, MatPow):
            A_base, A_exp = A.args
        else:
            A_base, A_exp = A, S.One

        if isinstance(B, MatPow):
            B_base, B_exp = B.args
        else:
            B_base, B_exp = B, S.One

        if A_base == B_base:
            # X**a * X**b -> X**(a + b)
            new_exp = A_exp + B_exp
            new_args[-1] = MatPow(A_base, new_exp).doit(deep=False)
            continue
        elif not isinstance(B_base, MatrixBase):
            # If A's base is the inverse of B's base, then
            # X**a * (X**-1)**b -> X**(a - b). Explicit MatrixBase bases are
            # skipped -- presumably to avoid computing explicit inverses.
            try:
                B_base_inv = B_base.inverse()
            except NonInvertibleMatrixError:
                B_base_inv = None
            if B_base_inv is not None and A_base == B_base_inv:
                new_exp = A_exp - B_exp
                new_args[-1] = MatPow(A_base, new_exp).doit(deep=False)
                continue
        new_args.append(B)

    return newmul(factor, *new_args)
  315. def combine_permutations(mul):
  316. """Refine products of permutation matrices as the products of cycles.
  317. """
  318. args = mul.args
  319. l = len(args)
  320. if l < 2:
  321. return mul
  322. result = [args[0]]
  323. for i in range(1, l):
  324. A = result[-1]
  325. B = args[i]
  326. if isinstance(A, PermutationMatrix) and \
  327. isinstance(B, PermutationMatrix):
  328. cycle_1 = A.args[0]
  329. cycle_2 = B.args[0]
  330. result[-1] = PermutationMatrix(cycle_1 * cycle_2)
  331. else:
  332. result.append(B)
  333. return MatMul(*result)
  334. def combine_one_matrices(mul):
  335. """
  336. Combine products of OneMatrix
  337. e.g. OneMatrix(2, 3) * OneMatrix(3, 4) -> 3 * OneMatrix(2, 4)
  338. """
  339. factor, args = mul.as_coeff_matrices()
  340. new_args = [args[0]]
  341. for B in args[1:]:
  342. A = new_args[-1]
  343. if not isinstance(A, OneMatrix) or not isinstance(B, OneMatrix):
  344. new_args.append(B)
  345. continue
  346. new_args.pop()
  347. new_args.append(OneMatrix(A.shape[0], B.shape[1]))
  348. factor *= A.shape[1]
  349. return newmul(factor, *new_args)
  350. def distribute_monom(mul):
  351. """
  352. Simplify MatMul expressions but distributing
  353. rational term to MatMul.
  354. e.g. 2*(A+B) -> 2*A + 2*B
  355. """
  356. args = mul.args
  357. if len(args) == 2:
  358. from .matadd import MatAdd
  359. if args[0].is_MatAdd and args[1].is_Rational:
  360. return MatAdd(*[MatMul(mat, args[1]).doit() for mat in args[0].args])
  361. if args[1].is_MatAdd and args[0].is_Rational:
  362. return MatAdd(*[MatMul(args[0], mat).doit() for mat in args[1].args])
  363. return mul
# Canonicalization rules for MatMul, tried in the order listed.
# NOTE(review): per sympy.strategies (not shown here), ``do_one`` presumably
# applies the first rule that changes the expression and ``exhaust`` repeats
# that until nothing changes, so reordering this tuple can change results.
rules = (
    distribute_monom, any_zeros, remove_ids, combine_one_matrices, combine_powers, unpack, rm_id(lambda x: x == 1),
    merge_explicit, factor_in_front, flatten, combine_permutations)

# ``typed`` routes only MatMul nodes through the rule set above.
canonicalize = exhaust(typed({MatMul: do_one(*rules)}))
  368. def only_squares(*matrices):
  369. """factor matrices only if they are square"""
  370. if matrices[0].rows != matrices[-1].cols:
  371. raise RuntimeError("Invalid matrices being multiplied")
  372. out = []
  373. start = 0
  374. for i, M in enumerate(matrices):
  375. if M.cols == matrices[start].rows:
  376. out.append(MatMul(*matrices[start:i+1]).doit())
  377. start = i+1
  378. return out
  379. def refine_MatMul(expr, assumptions):
  380. """
  381. >>> from sympy import MatrixSymbol, Q, assuming, refine
  382. >>> X = MatrixSymbol('X', 2, 2)
  383. >>> expr = X * X.T
  384. >>> print(expr)
  385. X*X.T
  386. >>> with assuming(Q.orthogonal(X)):
  387. ... print(refine(expr))
  388. I
  389. """
  390. newargs = []
  391. exprargs = []
  392. for args in expr.args:
  393. if args.is_Matrix:
  394. exprargs.append(args)
  395. else:
  396. newargs.append(args)
  397. last = exprargs[0]
  398. for arg in exprargs[1:]:
  399. if arg == last.T and ask(Q.orthogonal(arg), assumptions):
  400. last = Identity(arg.shape[0])
  401. elif arg == last.conjugate() and ask(Q.unitary(arg), assumptions):
  402. last = Identity(arg.shape[0])
  403. else:
  404. newargs.append(last)
  405. last = arg
  406. newargs.append(last)
  407. return MatMul(*newargs)
# Make refine() use refine_MatMul for MatMul expressions; handlers_dict is
# presumably keyed by class name (see sympy.assumptions.refine).
handlers_dict['MatMul'] = refine_MatMul