transpose.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from sympy.core.basic import Basic
  2. from sympy.matrices.expressions.matexpr import MatrixExpr
  3. class Transpose(MatrixExpr):
  4. """
  5. The transpose of a matrix expression.
  6. This is a symbolic object that simply stores its argument without
  7. evaluating it. To actually compute the transpose, use the ``transpose()``
  8. function, or the ``.T`` attribute of matrices.
  9. Examples
  10. ========
  11. >>> from sympy import MatrixSymbol, Transpose, transpose
  12. >>> A = MatrixSymbol('A', 3, 5)
  13. >>> B = MatrixSymbol('B', 5, 3)
  14. >>> Transpose(A)
  15. A.T
  16. >>> A.T == transpose(A) == Transpose(A)
  17. True
  18. >>> Transpose(A*B)
  19. (A*B).T
  20. >>> transpose(A*B)
  21. B.T*A.T
  22. """
  23. is_Transpose = True
  24. def doit(self, **hints):
  25. arg = self.arg
  26. if hints.get('deep', True) and isinstance(arg, Basic):
  27. arg = arg.doit(**hints)
  28. _eval_transpose = getattr(arg, '_eval_transpose', None)
  29. if _eval_transpose is not None:
  30. result = _eval_transpose()
  31. return result if result is not None else Transpose(arg)
  32. else:
  33. return Transpose(arg)
  34. @property
  35. def arg(self):
  36. return self.args[0]
  37. @property
  38. def shape(self):
  39. return self.arg.shape[::-1]
  40. def _entry(self, i, j, expand=False, **kwargs):
  41. return self.arg._entry(j, i, expand=expand, **kwargs)
  42. def _eval_adjoint(self):
  43. return self.arg.conjugate()
  44. def _eval_conjugate(self):
  45. return self.arg.adjoint()
  46. def _eval_transpose(self):
  47. return self.arg
  48. def _eval_trace(self):
  49. from .trace import Trace
  50. return Trace(self.arg) # Trace(X.T) => Trace(X)
  51. def _eval_determinant(self):
  52. from sympy.matrices.expressions.determinant import det
  53. return det(self.arg)
  54. def _eval_derivative(self, x):
  55. # x is a scalar:
  56. return self.arg._eval_derivative(x)
  57. def _eval_derivative_matrix_lines(self, x):
  58. lines = self.args[0]._eval_derivative_matrix_lines(x)
  59. return [i.transpose() for i in lines]
  60. def transpose(expr):
  61. """Matrix transpose"""
  62. return Transpose(expr).doit(deep=False)
  63. from sympy.assumptions.ask import ask, Q
  64. from sympy.assumptions.refine import handlers_dict
  65. def refine_Transpose(expr, assumptions):
  66. """
  67. >>> from sympy import MatrixSymbol, Q, assuming, refine
  68. >>> X = MatrixSymbol('X', 2, 2)
  69. >>> X.T
  70. X.T
  71. >>> with assuming(Q.symmetric(X)):
  72. ... print(refine(X.T))
  73. X
  74. """
  75. if ask(Q.symmetric(expr), assumptions):
  76. return expr.arg
  77. return expr
  78. handlers_dict['Transpose'] = refine_Transpose