matexpr.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888
  1. from __future__ import annotations
  2. from functools import wraps
  3. from sympy.core import S, Integer, Basic, Mul, Add
  4. from sympy.core.assumptions import check_assumptions
  5. from sympy.core.decorators import call_highest_priority
  6. from sympy.core.expr import Expr, ExprBuilder
  7. from sympy.core.logic import FuzzyBool
  8. from sympy.core.symbol import Str, Dummy, symbols, Symbol
  9. from sympy.core.sympify import SympifyError, _sympify
  10. from sympy.external.gmpy import SYMPY_INTS
  11. from sympy.functions import conjugate, adjoint
  12. from sympy.functions.special.tensor_functions import KroneckerDelta
  13. from sympy.matrices.exceptions import NonSquareMatrixError
  14. from sympy.matrices.kind import MatrixKind
  15. from sympy.matrices.matrixbase import MatrixBase
  16. from sympy.multipledispatch import dispatch
  17. from sympy.utilities.misc import filldedent
  18. def _sympifyit(arg, retval=None):
  19. # This version of _sympifyit sympifies MutableMatrix objects
  20. def deco(func):
  21. @wraps(func)
  22. def __sympifyit_wrapper(a, b):
  23. try:
  24. b = _sympify(b)
  25. return func(a, b)
  26. except SympifyError:
  27. return retval
  28. return __sympifyit_wrapper
  29. return deco
class MatrixExpr(Expr):
    """Superclass for Matrix Expressions

    MatrixExprs represent abstract matrices, linear transformations represented
    within a particular basis.

    Examples
    ========

    >>> from sympy import MatrixSymbol
    >>> A = MatrixSymbol('A', 3, 3)
    >>> y = MatrixSymbol('y', 3, 1)
    >>> x = (A.T*A).I * A * y

    See Also
    ========

    MatrixSymbol, MatAdd, MatMul, Transpose, Inverse
    """
    __slots__: tuple[str, ...] = ()

    # Should not be considered iterable by the
    # sympy.utilities.iterables.iterable function. Subclasses that actually are
    # iterable (i.e., explicit matrices) should set this to True.
    _iterable = False

    # Priority consulted by ``call_highest_priority`` on the arithmetic
    # operators below (presumably higher than the plain Expr default so that
    # matrix methods win over scalar ones -- confirm in sympy.core.expr).
    _op_priority = 11.0

    # Class-level type flags; subclasses override the ones that apply.
    is_Matrix: bool = True
    is_MatrixExpr: bool = True
    is_Identity: FuzzyBool = None  # fuzzy: None means "not known"
    is_Inverse = False
    is_Transpose = False
    is_ZeroMatrix = False
    is_MatAdd = False
    is_MatMul = False

    # Matrix expressions are treated as non-commutative, non-numeric objects.
    is_commutative = False
    is_number = False
    is_symbol = False
    is_scalar = False

    kind: MatrixKind = MatrixKind()

    def __new__(cls, *args, **kwargs):
        # Sympify every positional argument; keyword arguments are passed
        # through to Basic.__new__ unchanged.
        args = map(_sympify, args)
        return Basic.__new__(cls, *args, **kwargs)

    # The following is adapted from the core Expr object

    @property
    def shape(self) -> tuple[Expr, Expr]:
        """The ``(rows, cols)`` pair; concrete subclasses must provide it."""
        raise NotImplementedError

    @property
    def _add_handler(self):
        # Class used to build sums of matrix expressions.
        return MatAdd

    @property
    def _mul_handler(self):
        # Class used to build products of matrix expressions.
        return MatMul

    def __neg__(self):
        # -A is represented as the product (-1)*A.
        return MatMul(S.NegativeOne, self).doit()

    def __abs__(self):
        # abs() is not defined for matrix expressions.
        raise NotImplementedError

    # Binary operators.  ``_sympifyit`` sympifies ``other`` and returns
    # NotImplemented if that fails.  ``call_highest_priority`` presumably
    # dispatches to the named method of ``other`` instead when ``other`` has
    # a higher ``_op_priority`` (see sympy.core.decorators).

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__radd__')
    def __add__(self, other):
        return MatAdd(self, other).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__add__')
    def __radd__(self, other):
        return MatAdd(other, self).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rsub__')
    def __sub__(self, other):
        # A - B is built as A + (-B).
        return MatAdd(self, -other).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__sub__')
    def __rsub__(self, other):
        return MatAdd(other, -self).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rmul__')
    def __mul__(self, other):
        return MatMul(self, other).doit()

    # ``@`` is treated exactly like ``*`` (both build a MatMul), which is
    # why the priority fallbacks name __rmul__/__mul__.
    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rmul__')
    def __matmul__(self, other):
        return MatMul(self, other).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__mul__')
    def __rmul__(self, other):
        return MatMul(other, self).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__mul__')
    def __rmatmul__(self, other):
        return MatMul(other, self).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rpow__')
    def __pow__(self, other):
        return MatPow(self, other).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__pow__')
    def __rpow__(self, other):
        # Raising something to a matrix power is not supported.
        raise NotImplementedError("Matrix Power not defined")

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rtruediv__')
    def __truediv__(self, other):
        # A / b is expressed as A * b**-1.
        return self * other**S.NegativeOne

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__truediv__')
    def __rtruediv__(self, other):
        # Division *by* a matrix expression is not supported.
        raise NotImplementedError()
        # Former implementation, kept for reference:
        # return MatMul(other, Pow(self, S.NegativeOne))

    @property
    def rows(self):
        """Number of rows (first entry of ``shape``)."""
        return self.shape[0]

    @property
    def cols(self):
        """Number of columns (second entry of ``shape``)."""
        return self.shape[1]

    @property
    def is_square(self) -> bool | None:
        """True/False when squareness follows from ``shape``, otherwise None."""
        rows, cols = self.shape
        # Both dimensions are concrete integers: compare them directly.
        if isinstance(rows, Integer) and isinstance(cols, Integer):
            return rows == cols
        # Symbolic dimensions: identical expressions mean square, anything
        # else is undecided.
        if rows == cols:
            return True
        return None

    def _eval_conjugate(self):
        # Elementwise conjugate expressed as the adjoint of the transpose.
        from sympy.matrices.expressions.adjoint import Adjoint
        return Adjoint(Transpose(self))

    def as_real_imag(self, deep=True, **hints):
        # ``deep`` and ``hints`` are accepted for API compatibility but unused.
        return self._eval_as_real_imag()

    def _eval_as_real_imag(self):
        # re(A) = (A + conj(A))/2 and im(A) = (A - conj(A))/(2*I).
        real = S.Half * (self + self._eval_conjugate())
        im = (self - self._eval_conjugate())/(2*S.ImaginaryUnit)
        return (real, im)

    # Default (unevaluated) implementations; subclasses override them to
    # simplify when possible.

    def _eval_inverse(self):
        return Inverse(self)

    def _eval_determinant(self):
        return Determinant(self)

    def _eval_transpose(self):
        return Transpose(self)

    def _eval_trace(self):
        # None means "no simplified trace available".
        return None

    def _eval_power(self, exp):
        """
        Override this in sub-classes to implement simplification of powers. The cases where the exponent
        is -1, 0, 1 are already covered in MatPow.doit(), so implementations can exclude these cases.
        """
        return MatPow(self, exp)

    def _eval_simplify(self, **kwargs):
        # Atoms are already as simple as they get; otherwise simplify every
        # argument and rebuild the expression with the same class.
        if self.is_Atom:
            return self
        else:
            from sympy.simplify import simplify
            return self.func(*[simplify(x, **kwargs) for x in self.args])

    def _eval_adjoint(self):
        from sympy.matrices.expressions.adjoint import Adjoint
        return Adjoint(self)

    def _eval_derivative_n_times(self, x, n):
        # Deliberately use Basic's implementation rather than Expr's.
        return Basic._eval_derivative_n_times(self, x, n)

    def _eval_derivative(self, x):
        # `x` is a scalar:
        if self.has(x):
            # See if there are other methods using it:
            return super()._eval_derivative(x)
        else:
            # Independent of x: the derivative is a zero matrix of equal shape.
            return ZeroMatrix(*self.shape)

    @classmethod
    def _check_dim(cls, dim):
        """Helper function to check invalid matrix dimensions

        Raises ValueError if ``dim`` is a Float or is known not to be a
        nonnegative integer.  Dimensions whose status cannot be decided
        (``check_assumptions`` returns None) are accepted.
        """
        ok = not dim.is_Float and check_assumptions(
            dim, integer=True, nonnegative=True)
        if ok is False:
            raise ValueError(
                "The dimension specification {} should be "
                "a nonnegative integer.".format(dim))

    def _entry(self, i, j, **kwargs):
        # Element access hook used by __getitem__; subclasses implement it.
        raise NotImplementedError(
            "Indexing not implemented for %s" % self.__class__.__name__)

    def adjoint(self):
        return adjoint(self)

    def as_coeff_Mul(self, rational=False):
        """Efficiently extract the coefficient of a product."""
        return S.One, self

    def conjugate(self):
        return conjugate(self)

    def transpose(self):
        from sympy.matrices.expressions.transpose import transpose
        return transpose(self)

    @property
    def T(self):
        '''Matrix transposition'''
        return self.transpose()

    def inverse(self):
        """Return the inverse; raises NonSquareMatrixError if known non-square."""
        # Only a definite False is rejected; undecided squareness passes.
        if self.is_square is False:
            raise NonSquareMatrixError('Inverse of non-square matrix')
        return self._eval_inverse()

    def inv(self):
        """Alias for :meth:`inverse`."""
        return self.inverse()

    def det(self):
        """Determinant of this matrix expression."""
        from sympy.matrices.expressions.determinant import det
        return det(self)

    @property
    def I(self):
        """Alias for :meth:`inverse`."""
        return self.inverse()

    def valid_index(self, i, j):
        """Return a falsy value if ``(i, j)`` is known to be out of range.

        Comparisons are tested with ``!= False`` so that undecidable
        (symbolic) comparisons are accepted.  Negative indices down to
        ``-rows``/``-cols`` are allowed.
        """
        def is_valid(idx):
            return isinstance(idx, (int, Integer, Symbol, Expr))
        # NOTE(review): by operator precedence, ``self.rows is None`` only
        # short-circuits the row-bound checks; the column checks always run.
        return (is_valid(i) and is_valid(j) and
                (self.rows is None or
                (i >= -self.rows) != False and (i < self.rows) != False) and
                (j >= -self.cols) != False and (j < self.cols) != False)

    def __getitem__(self, key):
        # A bare slice selects rows and keeps every column.
        if not isinstance(key, tuple) and isinstance(key, slice):
            from sympy.matrices.expressions.slice import MatrixSlice
            return MatrixSlice(self, key, (0, None, 1))
        if isinstance(key, tuple) and len(key) == 2:
            i, j = key
            # Any slice in the pair produces a MatrixSlice.
            if isinstance(i, slice) or isinstance(j, slice):
                from sympy.matrices.expressions.slice import MatrixSlice
                return MatrixSlice(self, i, j)
            i, j = _sympify(i), _sympify(j)
            if self.valid_index(i, j) != False:
                return self._entry(i, j)
            else:
                raise IndexError("Invalid indices (%s, %s)" % (i, j))
        elif isinstance(key, (SYMPY_INTS, Integer)):
            # row-wise decomposition of matrix
            rows, cols = self.shape
            # allow single indexing if number of columns is known
            if not isinstance(cols, Integer):
                raise IndexError(filldedent('''
Single indexing is only supported when the number
of columns is known.'''))
            key = _sympify(key)
            # Flat (row-major) index -> (row, column).
            i = key // cols
            j = key % cols
            if self.valid_index(i, j) != False:
                return self._entry(i, j)
            else:
                raise IndexError("Invalid index %s" % key)
        elif isinstance(key, (Symbol, Expr)):
            raise IndexError(filldedent('''
Only integers may be used when addressing the matrix
with a single index.'''))
        raise IndexError("Invalid index, wanted %s[i,j]" % self)

    def _is_shape_symbolic(self) -> bool:
        # True when either dimension is not a concrete integer.
        return (not isinstance(self.rows, (SYMPY_INTS, Integer))
            or not isinstance(self.cols, (SYMPY_INTS, Integer)))

    def as_explicit(self):
        """
        Returns a dense Matrix with elements represented explicitly

        Returns an object of type ImmutableDenseMatrix.

        Raises ValueError if the shape is symbolic.

        Examples
        ========

        >>> from sympy import Identity
        >>> I = Identity(3)
        >>> I
        I
        >>> I.as_explicit()
        Matrix([
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]])

        See Also
        ========

        as_mutable: returns mutable Matrix type
        """
        if self._is_shape_symbolic():
            raise ValueError(
                'Matrix with symbolic shape '
                'cannot be represented explicitly.')
        from sympy.matrices.immutable import ImmutableDenseMatrix
        return ImmutableDenseMatrix([[self[i, j]
                            for j in range(self.cols)]
                            for i in range(self.rows)])

    def as_mutable(self):
        """
        Returns a dense, mutable matrix with elements represented explicitly

        Examples
        ========

        >>> from sympy import Identity
        >>> I = Identity(3)
        >>> I
        I
        >>> I.shape
        (3, 3)
        >>> I.as_mutable()
        Matrix([
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]])

        See Also
        ========

        as_explicit: returns ImmutableDenseMatrix
        """
        return self.as_explicit().as_mutable()

    def __array__(self, dtype=object, copy=None):
        # NumPy array protocol.  A new array is always built, so an explicit
        # copy=False request is rejected.  ``dtype`` is accepted but the array
        # is always created with dtype=object.
        if copy is not None and not copy:
            raise TypeError("Cannot implement copy=False when converting Matrix to ndarray")
        from numpy import empty
        a = empty(self.shape, dtype=object)
        for i in range(self.rows):
            for j in range(self.cols):
                a[i, j] = self[i, j]
        return a

    def equals(self, other):
        """
        Test elementwise equality between matrices, potentially of different
        types

        >>> from sympy import Identity, eye
        >>> Identity(3).equals(eye(3))
        True
        """
        return self.as_explicit().equals(other)

    def canonicalize(self):
        return self

    def as_coeff_mmul(self):
        # Split into a scalar coefficient (always 1 here) and a MatMul.
        return S.One, MatMul(self)

    @staticmethod
    def from_index_summation(expr, first_index=None, last_index=None, dimensions=None):
        r"""
        Parse expression of matrices with explicitly summed indices into a
        matrix expression without indices, if possible.

        This transformation expressed in mathematical notation:

        `\sum_{j=0}^{N-1} A_{i,j} B_{j,k} \Longrightarrow \mathbf{A}\cdot \mathbf{B}`

        Optional parameter ``first_index``: specify which free index to use as
        the index starting the expression.

        Optional parameter ``last_index``: passed on after ``first_index`` in
        the ``first_indices`` list given to the converter.

        The ``dimensions`` parameter is currently unused.

        Examples
        ========

        >>> from sympy import MatrixSymbol, MatrixExpr, Sum
        >>> from sympy.abc import i, j, k, l, N
        >>> A = MatrixSymbol("A", N, N)
        >>> B = MatrixSymbol("B", N, N)
        >>> expr = Sum(A[i, j]*B[j, k], (j, 0, N-1))
        >>> MatrixExpr.from_index_summation(expr)
        A*B

        Transposition is detected:

        >>> expr = Sum(A[j, i]*B[j, k], (j, 0, N-1))
        >>> MatrixExpr.from_index_summation(expr)
        A.T*B

        Detect the trace:

        >>> expr = Sum(A[i, i], (i, 0, N-1))
        >>> MatrixExpr.from_index_summation(expr)
        Trace(A)

        More complicated expressions:

        >>> expr = Sum(A[i, j]*B[k, j]*A[l, k], (j, 0, N-1), (k, 0, N-1))
        >>> MatrixExpr.from_index_summation(expr)
        A*B.T*A.T
        """
        from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array
        from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix
        first_indices = []
        if first_index is not None:
            first_indices.append(first_index)
        if last_index is not None:
            first_indices.append(last_index)
        # Indexed expression -> array expression -> matrix expression.
        arr = convert_indexed_to_array(expr, first_indices=first_indices)
        return convert_array_to_matrix(arr)

    def applyfunc(self, func):
        """Return an unevaluated elementwise application of ``func``."""
        from .applyfunc import ElementwiseApplyFunction
        return ElementwiseApplyFunction(func, self)
  379. @dispatch(MatrixExpr, Expr)
  380. def _eval_is_eq(lhs, rhs): # noqa:F811
  381. return False
  382. @dispatch(MatrixExpr, MatrixExpr) # type: ignore
  383. def _eval_is_eq(lhs, rhs): # noqa:F811
  384. if lhs.shape != rhs.shape:
  385. return False
  386. if (lhs - rhs).is_ZeroMatrix:
  387. return True
  388. def get_postprocessor(cls):
  389. def _postprocessor(expr):
  390. # To avoid circular imports, we can't have MatMul/MatAdd on the top level
  391. mat_class = {Mul: MatMul, Add: MatAdd}[cls]
  392. nonmatrices = []
  393. matrices = []
  394. for term in expr.args:
  395. if isinstance(term, MatrixExpr):
  396. matrices.append(term)
  397. else:
  398. nonmatrices.append(term)
  399. if not matrices:
  400. return cls._from_args(nonmatrices)
  401. if nonmatrices:
  402. if cls == Mul:
  403. for i in range(len(matrices)):
  404. if not matrices[i].is_MatrixExpr:
  405. # If one of the matrices explicit, absorb the scalar into it
  406. # (doit will combine all explicit matrices into one, so it
  407. # doesn't matter which)
  408. matrices[i] = matrices[i].__mul__(cls._from_args(nonmatrices))
  409. nonmatrices = []
  410. break
  411. else:
  412. # Maintain the ability to create Add(scalar, matrix) without
  413. # raising an exception. That way different algorithms can
  414. # replace matrix expressions with non-commutative symbols to
  415. # manipulate them like non-commutative scalars.
  416. return cls._from_args(nonmatrices + [mat_class(*matrices).doit(deep=False)])
  417. if mat_class == MatAdd:
  418. return mat_class(*matrices).doit(deep=False)
  419. return mat_class(cls._from_args(nonmatrices), *matrices).doit(deep=False)
  420. return _postprocessor
  421. Basic._constructor_postprocessor_mapping[MatrixExpr] = {
  422. "Mul": [get_postprocessor(Mul)],
  423. "Add": [get_postprocessor(Add)],
  424. }
  425. def _matrix_derivative(expr, x, old_algorithm=False):
  426. if isinstance(expr, MatrixBase) or isinstance(x, MatrixBase):
  427. # Do not use array expressions for explicit matrices:
  428. old_algorithm = True
  429. if old_algorithm:
  430. return _matrix_derivative_old_algorithm(expr, x)
  431. from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array
  432. from sympy.tensor.array.expressions.arrayexpr_derivatives import array_derive
  433. from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix
  434. array_expr = convert_matrix_to_array(expr)
  435. diff_array_expr = array_derive(array_expr, x)
  436. diff_matrix_expr = convert_array_to_matrix(diff_array_expr)
  437. return diff_matrix_expr
  438. def _matrix_derivative_old_algorithm(expr, x):
  439. from sympy.tensor.array.array_derivatives import ArrayDerivative
  440. lines = expr._eval_derivative_matrix_lines(x)
  441. parts = [i.build() for i in lines]
  442. from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix
  443. parts = [[convert_array_to_matrix(j) for j in i] for i in parts]
  444. def _get_shape(elem):
  445. if isinstance(elem, MatrixExpr):
  446. return elem.shape
  447. return 1, 1
  448. def get_rank(parts):
  449. return sum(j not in (1, None) for i in parts for j in _get_shape(i))
  450. ranks = [get_rank(i) for i in parts]
  451. rank = ranks[0]
  452. def contract_one_dims(parts):
  453. if len(parts) == 1:
  454. return parts[0]
  455. else:
  456. p1, p2 = parts[:2]
  457. if p2.is_Matrix:
  458. p2 = p2.T
  459. if p1 == Identity(1):
  460. pbase = p2
  461. elif p2 == Identity(1):
  462. pbase = p1
  463. else:
  464. pbase = p1*p2
  465. if len(parts) == 2:
  466. return pbase
  467. else: # len(parts) > 2
  468. if pbase.is_Matrix:
  469. raise ValueError("")
  470. return pbase*Mul.fromiter(parts[2:])
  471. if rank <= 2:
  472. return Add.fromiter([contract_one_dims(i) for i in parts])
  473. return ArrayDerivative(expr, x)
  474. class MatrixElement(Expr):
  475. parent = property(lambda self: self.args[0])
  476. i = property(lambda self: self.args[1])
  477. j = property(lambda self: self.args[2])
  478. _diff_wrt = True
  479. is_symbol = True
  480. is_commutative = True
  481. def __new__(cls, name, n, m):
  482. n, m = map(_sympify, (n, m))
  483. if isinstance(name, str):
  484. name = Symbol(name)
  485. else:
  486. if isinstance(name, MatrixBase):
  487. if n.is_Integer and m.is_Integer:
  488. return name[n, m]
  489. name = _sympify(name) # change mutable into immutable
  490. else:
  491. name = _sympify(name)
  492. if not isinstance(name.kind, MatrixKind):
  493. raise TypeError("First argument of MatrixElement should be a matrix")
  494. if not getattr(name, 'valid_index', lambda n, m: True)(n, m):
  495. raise IndexError('indices out of range')
  496. obj = Expr.__new__(cls, name, n, m)
  497. return obj
  498. @property
  499. def symbol(self):
  500. return self.args[0]
  501. def doit(self, **hints):
  502. deep = hints.get('deep', True)
  503. if deep:
  504. args = [arg.doit(**hints) for arg in self.args]
  505. else:
  506. args = self.args
  507. return args[0][args[1], args[2]]
  508. @property
  509. def indices(self):
  510. return self.args[1:]
  511. def _eval_derivative(self, v):
  512. if not isinstance(v, MatrixElement):
  513. return self.parent.diff(v)[self.i, self.j]
  514. M = self.args[0]
  515. m, n = self.parent.shape
  516. if M == v.args[0]:
  517. return KroneckerDelta(self.args[1], v.args[1], (0, m-1)) * \
  518. KroneckerDelta(self.args[2], v.args[2], (0, n-1))
  519. if isinstance(M, Inverse):
  520. from sympy.concrete.summations import Sum
  521. i, j = self.args[1:]
  522. i1, i2 = symbols("z1, z2", cls=Dummy)
  523. Y = M.args[0]
  524. r1, r2 = Y.shape
  525. return -Sum(M[i, i1]*Y[i1, i2].diff(v)*M[i2, j], (i1, 0, r1-1), (i2, 0, r2-1))
  526. if self.has(v.args[0]):
  527. return None
  528. return S.Zero
  529. class MatrixSymbol(MatrixExpr):
  530. """Symbolic representation of a Matrix object
  531. Creates a SymPy Symbol to represent a Matrix. This matrix has a shape and
  532. can be included in Matrix Expressions
  533. Examples
  534. ========
  535. >>> from sympy import MatrixSymbol, Identity
  536. >>> A = MatrixSymbol('A', 3, 4) # A 3 by 4 Matrix
  537. >>> B = MatrixSymbol('B', 4, 3) # A 4 by 3 Matrix
  538. >>> A.shape
  539. (3, 4)
  540. >>> 2*A*B + Identity(3)
  541. I + 2*A*B
  542. """
  543. is_commutative = False
  544. is_symbol = True
  545. _diff_wrt = True
  546. def __new__(cls, name, n, m):
  547. n, m = _sympify(n), _sympify(m)
  548. cls._check_dim(m)
  549. cls._check_dim(n)
  550. if isinstance(name, str):
  551. name = Str(name)
  552. obj = Basic.__new__(cls, name, n, m)
  553. return obj
  554. @property
  555. def shape(self):
  556. return self.args[1], self.args[2]
  557. @property
  558. def name(self):
  559. return self.args[0].name
  560. def _entry(self, i, j, **kwargs):
  561. return MatrixElement(self, i, j)
  562. @property
  563. def free_symbols(self):
  564. return {self}
  565. def _eval_simplify(self, **kwargs):
  566. return self
  567. def _eval_derivative(self, x):
  568. # x is a scalar:
  569. return ZeroMatrix(self.shape[0], self.shape[1])
  570. def _eval_derivative_matrix_lines(self, x):
  571. if self != x:
  572. first = ZeroMatrix(x.shape[0], self.shape[0]) if self.shape[0] != 1 else S.Zero
  573. second = ZeroMatrix(x.shape[1], self.shape[1]) if self.shape[1] != 1 else S.Zero
  574. return [_LeftRightArgs(
  575. [first, second],
  576. )]
  577. else:
  578. first = Identity(self.shape[0]) if self.shape[0] != 1 else S.One
  579. second = Identity(self.shape[1]) if self.shape[1] != 1 else S.One
  580. return [_LeftRightArgs(
  581. [first, second],
  582. )]
  583. def matrix_symbols(expr):
  584. return [sym for sym in expr.free_symbols if sym.is_Matrix]
  585. class _LeftRightArgs:
  586. r"""
  587. Helper class to compute matrix derivatives.
  588. The logic: when an expression is derived by a matrix `X_{mn}`, two lines of
  589. matrix multiplications are created: the one contracted to `m` (first line),
  590. and the one contracted to `n` (second line).
  591. Transposition flips the side by which new matrices are connected to the
  592. lines.
  593. The trace connects the end of the two lines.
  594. """
  595. def __init__(self, lines, higher=S.One):
  596. self._lines = list(lines)
  597. self._first_pointer_parent = self._lines
  598. self._first_pointer_index = 0
  599. self._first_line_index = 0
  600. self._second_pointer_parent = self._lines
  601. self._second_pointer_index = 1
  602. self._second_line_index = 1
  603. self.higher = higher
  604. @property
  605. def first_pointer(self):
  606. return self._first_pointer_parent[self._first_pointer_index]
  607. @first_pointer.setter
  608. def first_pointer(self, value):
  609. self._first_pointer_parent[self._first_pointer_index] = value
  610. @property
  611. def second_pointer(self):
  612. return self._second_pointer_parent[self._second_pointer_index]
  613. @second_pointer.setter
  614. def second_pointer(self, value):
  615. self._second_pointer_parent[self._second_pointer_index] = value
  616. def __repr__(self):
  617. built = [self._build(i) for i in self._lines]
  618. return "_LeftRightArgs(lines=%s, higher=%s)" % (
  619. built,
  620. self.higher,
  621. )
  622. def transpose(self):
  623. self._first_pointer_parent, self._second_pointer_parent = self._second_pointer_parent, self._first_pointer_parent
  624. self._first_pointer_index, self._second_pointer_index = self._second_pointer_index, self._first_pointer_index
  625. self._first_line_index, self._second_line_index = self._second_line_index, self._first_line_index
  626. return self
  627. @staticmethod
  628. def _build(expr):
  629. if isinstance(expr, ExprBuilder):
  630. return expr.build()
  631. if isinstance(expr, list):
  632. if len(expr) == 1:
  633. return expr[0]
  634. else:
  635. return expr[0](*[_LeftRightArgs._build(i) for i in expr[1]])
  636. else:
  637. return expr
  638. def build(self):
  639. data = [self._build(i) for i in self._lines]
  640. if self.higher != 1:
  641. data += [self._build(self.higher)]
  642. data = list(data)
  643. return data
  644. def matrix_form(self):
  645. if self.first != 1 and self.higher != 1:
  646. raise ValueError("higher dimensional array cannot be represented")
  647. def _get_shape(elem):
  648. if isinstance(elem, MatrixExpr):
  649. return elem.shape
  650. return (None, None)
  651. if _get_shape(self.first)[1] != _get_shape(self.second)[1]:
  652. # Remove one-dimensional identity matrices:
  653. # (this is needed by `a.diff(a)` where `a` is a vector)
  654. if _get_shape(self.second) == (1, 1):
  655. return self.first*self.second[0, 0]
  656. if _get_shape(self.first) == (1, 1):
  657. return self.first[1, 1]*self.second.T
  658. raise ValueError("incompatible shapes")
  659. if self.first != 1:
  660. return self.first*self.second.T
  661. else:
  662. return self.higher
  663. def rank(self):
  664. """
  665. Number of dimensions different from trivial (warning: not related to
  666. matrix rank).
  667. """
  668. rank = 0
  669. if self.first != 1:
  670. rank += sum(i != 1 for i in self.first.shape)
  671. if self.second != 1:
  672. rank += sum(i != 1 for i in self.second.shape)
  673. if self.higher != 1:
  674. rank += 2
  675. return rank
  676. def _multiply_pointer(self, pointer, other):
  677. from ...tensor.array.expressions.array_expressions import ArrayTensorProduct
  678. from ...tensor.array.expressions.array_expressions import ArrayContraction
  679. subexpr = ExprBuilder(
  680. ArrayContraction,
  681. [
  682. ExprBuilder(
  683. ArrayTensorProduct,
  684. [
  685. pointer,
  686. other
  687. ]
  688. ),
  689. (1, 2)
  690. ],
  691. validator=ArrayContraction._validate
  692. )
  693. return subexpr
  694. def append_first(self, other):
  695. self.first_pointer *= other
  696. def append_second(self, other):
  697. self.second_pointer *= other
  698. def _make_matrix(x):
  699. from sympy.matrices.immutable import ImmutableDenseMatrix
  700. if isinstance(x, MatrixExpr):
  701. return x
  702. return ImmutableDenseMatrix([[x]])
  703. from .matmul import MatMul
  704. from .matadd import MatAdd
  705. from .matpow import MatPow
  706. from .transpose import Transpose
  707. from .inverse import Inverse
  708. from .special import ZeroMatrix, Identity
  709. from .determinant import Determinant