matpow.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. from .matexpr import MatrixExpr
  2. from .special import Identity
  3. from sympy.core import S
  4. from sympy.core.expr import ExprBuilder
  5. from sympy.core.cache import cacheit
  6. from sympy.core.power import Pow
  7. from sympy.core.sympify import _sympify
  8. from sympy.matrices import MatrixBase
  9. from sympy.matrices.exceptions import NonSquareMatrixError
  10. class MatPow(MatrixExpr):
  11. def __new__(cls, base, exp, evaluate=False, **options):
  12. base = _sympify(base)
  13. if not base.is_Matrix:
  14. raise TypeError("MatPow base should be a matrix")
  15. if base.is_square is False:
  16. raise NonSquareMatrixError("Power of non-square matrix %s" % base)
  17. exp = _sympify(exp)
  18. obj = super().__new__(cls, base, exp)
  19. if evaluate:
  20. obj = obj.doit(deep=False)
  21. return obj
  22. @property
  23. def base(self):
  24. return self.args[0]
  25. @property
  26. def exp(self):
  27. return self.args[1]
  28. @property
  29. def shape(self):
  30. return self.base.shape
  31. @cacheit
  32. def _get_explicit_matrix(self):
  33. return self.base.as_explicit()**self.exp
  34. def _entry(self, i, j, **kwargs):
  35. from sympy.matrices.expressions import MatMul
  36. A = self.doit()
  37. if isinstance(A, MatPow):
  38. # We still have a MatPow, make an explicit MatMul out of it.
  39. if A.exp.is_Integer and A.exp.is_positive:
  40. A = MatMul(*[A.base for k in range(A.exp)])
  41. elif not self._is_shape_symbolic():
  42. return A._get_explicit_matrix()[i, j]
  43. else:
  44. # Leave the expression unevaluated:
  45. from sympy.matrices.expressions.matexpr import MatrixElement
  46. return MatrixElement(self, i, j)
  47. return A[i, j]
  48. def doit(self, **hints):
  49. if hints.get('deep', True):
  50. base, exp = (arg.doit(**hints) for arg in self.args)
  51. else:
  52. base, exp = self.args
  53. # combine all powers, e.g. (A ** 2) ** 3 -> A ** 6
  54. while isinstance(base, MatPow):
  55. exp *= base.args[1]
  56. base = base.args[0]
  57. if isinstance(base, MatrixBase):
  58. # Delegate
  59. return base ** exp
  60. # Handle simple cases so that _eval_power() in MatrixExpr sub-classes can ignore them
  61. if exp == S.One:
  62. return base
  63. if exp == S.Zero:
  64. return Identity(base.rows)
  65. if exp == S.NegativeOne:
  66. from sympy.matrices.expressions import Inverse
  67. return Inverse(base).doit(**hints)
  68. eval_power = getattr(base, '_eval_power', None)
  69. if eval_power is not None:
  70. return eval_power(exp)
  71. return MatPow(base, exp)
  72. def _eval_transpose(self):
  73. base, exp = self.args
  74. return MatPow(base.transpose(), exp)
  75. def _eval_adjoint(self):
  76. base, exp = self.args
  77. return MatPow(base.adjoint(), exp)
  78. def _eval_conjugate(self):
  79. base, exp = self.args
  80. return MatPow(base.conjugate(), exp)
  81. def _eval_derivative(self, x):
  82. return Pow._eval_derivative(self, x)
  83. def _eval_derivative_matrix_lines(self, x):
  84. from sympy.tensor.array.expressions.array_expressions import ArrayContraction
  85. from ...tensor.array.expressions.array_expressions import ArrayTensorProduct
  86. from .matmul import MatMul
  87. from .inverse import Inverse
  88. exp = self.exp
  89. if self.base.shape == (1, 1) and not exp.has(x):
  90. lr = self.base._eval_derivative_matrix_lines(x)
  91. for i in lr:
  92. subexpr = ExprBuilder(
  93. ArrayContraction,
  94. [
  95. ExprBuilder(
  96. ArrayTensorProduct,
  97. [
  98. Identity(1),
  99. i._lines[0],
  100. exp*self.base**(exp-1),
  101. i._lines[1],
  102. Identity(1),
  103. ]
  104. ),
  105. (0, 3, 4), (5, 7, 8)
  106. ],
  107. validator=ArrayContraction._validate
  108. )
  109. i._first_pointer_parent = subexpr.args[0].args
  110. i._first_pointer_index = 0
  111. i._second_pointer_parent = subexpr.args[0].args
  112. i._second_pointer_index = 4
  113. i._lines = [subexpr]
  114. return lr
  115. if (exp > 0) == True:
  116. newexpr = MatMul.fromiter([self.base for i in range(exp)])
  117. elif (exp == -1) == True:
  118. return Inverse(self.base)._eval_derivative_matrix_lines(x)
  119. elif (exp < 0) == True:
  120. newexpr = MatMul.fromiter([Inverse(self.base) for i in range(-exp)])
  121. elif (exp == 0) == True:
  122. return self.doit()._eval_derivative_matrix_lines(x)
  123. else:
  124. raise NotImplementedError("cannot evaluate %s derived by %s" % (self, x))
  125. return newexpr._eval_derivative_matrix_lines(x)
  126. def _eval_inverse(self):
  127. return MatPow(self.base, -self.exp)