immutable.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. from mpmath.matrices.matrices import _matrix
  2. from sympy.core import Basic, Dict, Tuple
  3. from sympy.core.numbers import Integer
  4. from sympy.core.cache import cacheit
  5. from sympy.core.sympify import _sympy_converter as sympify_converter, _sympify
  6. from sympy.matrices.dense import DenseMatrix
  7. from sympy.matrices.expressions import MatrixExpr
  8. from sympy.matrices.matrixbase import MatrixBase
  9. from sympy.matrices.repmatrix import RepMatrix
  10. from sympy.matrices.sparse import SparseRepMatrix
  11. from sympy.multipledispatch import dispatch
  12. def sympify_matrix(arg):
  13. return arg.as_immutable()
# Register sympify_matrix in the sympify converter table under MatrixBase.
sympify_converter[MatrixBase] = sympify_matrix
  15. def sympify_mpmath_matrix(arg):
  16. mat = [_sympify(x) for x in arg]
  17. return ImmutableDenseMatrix(arg.rows, arg.cols, mat)
# Register sympify_mpmath_matrix in the sympify converter table under
# mpmath's _matrix type.
sympify_converter[_matrix] = sympify_mpmath_matrix
  19. class ImmutableRepMatrix(RepMatrix, MatrixExpr): # type: ignore
  20. """Immutable matrix based on RepMatrix
  21. Uses DomainMAtrix as the internal representation.
  22. """
  23. #
  24. # This is a subclass of RepMatrix that adds/overrides some methods to make
  25. # the instances Basic and immutable. ImmutableRepMatrix is a superclass for
  26. # both ImmutableDenseMatrix and ImmutableSparseMatrix.
  27. #
  28. def __new__(cls, *args, **kwargs):
  29. return cls._new(*args, **kwargs)
  30. __hash__ = MatrixExpr.__hash__
  31. def copy(self):
  32. return self
  33. @property
  34. def cols(self):
  35. return self._cols
  36. @property
  37. def rows(self):
  38. return self._rows
  39. @property
  40. def shape(self):
  41. return self._rows, self._cols
  42. def as_immutable(self):
  43. return self
  44. def _entry(self, i, j, **kwargs):
  45. return self[i, j]
  46. def __setitem__(self, *args):
  47. raise TypeError("Cannot set values of {}".format(self.__class__))
  48. def is_diagonalizable(self, reals_only=False, **kwargs):
  49. return super().is_diagonalizable(
  50. reals_only=reals_only, **kwargs)
  51. is_diagonalizable.__doc__ = SparseRepMatrix.is_diagonalizable.__doc__
  52. is_diagonalizable = cacheit(is_diagonalizable)
  53. def analytic_func(self, f, x):
  54. return self.as_mutable().analytic_func(f, x).as_immutable()
  55. class ImmutableDenseMatrix(DenseMatrix, ImmutableRepMatrix): # type: ignore
  56. """Create an immutable version of a matrix.
  57. Examples
  58. ========
  59. >>> from sympy import eye, ImmutableMatrix
  60. >>> ImmutableMatrix(eye(3))
  61. Matrix([
  62. [1, 0, 0],
  63. [0, 1, 0],
  64. [0, 0, 1]])
  65. >>> _[0, 0] = 42
  66. Traceback (most recent call last):
  67. ...
  68. TypeError: Cannot set values of ImmutableDenseMatrix
  69. """
  70. # MatrixExpr is set as NotIterable, but we want explicit matrices to be
  71. # iterable
  72. _iterable = True
  73. _class_priority = 8
  74. _op_priority = 10.001
  75. @classmethod
  76. def _new(cls, *args, **kwargs):
  77. if len(args) == 1 and isinstance(args[0], ImmutableDenseMatrix):
  78. return args[0]
  79. if kwargs.get('copy', True) is False:
  80. if len(args) != 3:
  81. raise TypeError("'copy=False' requires a matrix be initialized as rows,cols,[list]")
  82. rows, cols, flat_list = args
  83. else:
  84. rows, cols, flat_list = cls._handle_creation_inputs(*args, **kwargs)
  85. flat_list = list(flat_list) # create a shallow copy
  86. rep = cls._flat_list_to_DomainMatrix(rows, cols, flat_list)
  87. return cls._fromrep(rep)
  88. @classmethod
  89. def _fromrep(cls, rep):
  90. rows, cols = rep.shape
  91. flat_list = rep.to_sympy().to_list_flat()
  92. obj = Basic.__new__(cls,
  93. Integer(rows),
  94. Integer(cols),
  95. Tuple(*flat_list, sympify=False))
  96. obj._rows = rows
  97. obj._cols = cols
  98. obj._rep = rep
  99. return obj
  100. # make sure ImmutableDenseMatrix is aliased as ImmutableMatrix
  101. ImmutableMatrix = ImmutableDenseMatrix
  102. class ImmutableSparseMatrix(SparseRepMatrix, ImmutableRepMatrix): # type:ignore
  103. """Create an immutable version of a sparse matrix.
  104. Examples
  105. ========
  106. >>> from sympy import eye, ImmutableSparseMatrix
  107. >>> ImmutableSparseMatrix(1, 1, {})
  108. Matrix([[0]])
  109. >>> ImmutableSparseMatrix(eye(3))
  110. Matrix([
  111. [1, 0, 0],
  112. [0, 1, 0],
  113. [0, 0, 1]])
  114. >>> _[0, 0] = 42
  115. Traceback (most recent call last):
  116. ...
  117. TypeError: Cannot set values of ImmutableSparseMatrix
  118. >>> _.shape
  119. (3, 3)
  120. """
  121. is_Matrix = True
  122. _class_priority = 9
  123. @classmethod
  124. def _new(cls, *args, **kwargs):
  125. rows, cols, smat = cls._handle_creation_inputs(*args, **kwargs)
  126. rep = cls._smat_to_DomainMatrix(rows, cols, smat)
  127. return cls._fromrep(rep)
  128. @classmethod
  129. def _fromrep(cls, rep):
  130. rows, cols = rep.shape
  131. smat = rep.to_sympy().to_dok()
  132. obj = Basic.__new__(cls, Integer(rows), Integer(cols), Dict(smat))
  133. obj._rows = rows
  134. obj._cols = cols
  135. obj._rep = rep
  136. return obj
  137. @dispatch(ImmutableDenseMatrix, ImmutableDenseMatrix)
  138. def _eval_is_eq(lhs, rhs): # noqa:F811
  139. """Helper method for Equality with matrices.sympy.
  140. Relational automatically converts matrices to ImmutableDenseMatrix
  141. instances, so this method only applies here. Returns True if the
  142. matrices are definitively the same, False if they are definitively
  143. different, and None if undetermined (e.g. if they contain Symbols).
  144. Returning None triggers default handling of Equalities.
  145. """
  146. if lhs.shape != rhs.shape:
  147. return False
  148. return (lhs - rhs).is_zero_matrix