test_determinant.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. import random
  2. import pytest
  3. from sympy.core.numbers import I
  4. from sympy.core.numbers import Rational
  5. from sympy.core.symbol import (Symbol, symbols)
  6. from sympy.functions.elementary.miscellaneous import sqrt
  7. from sympy.polys.polytools import Poly
  8. from sympy.matrices import Matrix, eye, ones
  9. from sympy.abc import x, y, z
  10. from sympy.testing.pytest import raises
  11. from sympy.matrices.exceptions import NonSquareMatrixError
  12. from sympy.functions.combinatorial.factorials import factorial, subfactorial
  13. @pytest.mark.parametrize("method", [
  14. # Evaluating these directly because they are never reached via M.det()
  15. Matrix._eval_det_bareiss, Matrix._eval_det_berkowitz,
  16. Matrix._eval_det_bird, Matrix._eval_det_laplace, Matrix._eval_det_lu
  17. ])
  18. @pytest.mark.parametrize("M, sol", [
  19. (Matrix(), 1),
  20. (Matrix([[0]]), 0),
  21. (Matrix([[5]]), 5),
  22. ])
  23. def test_eval_determinant(method, M, sol):
  24. assert method(M) == sol
  25. @pytest.mark.parametrize("method", [
  26. "domain-ge", "bareiss", "berkowitz", "bird", "laplace", "lu"])
  27. @pytest.mark.parametrize("M, sol", [
  28. (Matrix(( (-3, 2),
  29. ( 8, -5) )), -1),
  30. (Matrix(( (x, 1),
  31. (y, 2*y) )), 2*x*y - y),
  32. (Matrix(( (1, 1, 1),
  33. (1, 2, 3),
  34. (1, 3, 6) )), 1),
  35. (Matrix(( ( 3, -2, 0, 5),
  36. (-2, 1, -2, 2),
  37. ( 0, -2, 5, 0),
  38. ( 5, 0, 3, 4) )), -289),
  39. (Matrix(( ( 1, 2, 3, 4),
  40. ( 5, 6, 7, 8),
  41. ( 9, 10, 11, 12),
  42. (13, 14, 15, 16) )), 0),
  43. (Matrix(( (3, 2, 0, 0, 0),
  44. (0, 3, 2, 0, 0),
  45. (0, 0, 3, 2, 0),
  46. (0, 0, 0, 3, 2),
  47. (2, 0, 0, 0, 3) )), 275),
  48. (Matrix(( ( 3, 0, 0, 0),
  49. (-2, 1, 0, 0),
  50. ( 0, -2, 5, 0),
  51. ( 5, 0, 3, 4) )), 60),
  52. (Matrix(( ( 1, 0, 0, 0),
  53. ( 5, 0, 0, 0),
  54. ( 9, 10, 11, 0),
  55. (13, 14, 15, 16) )), 0),
  56. (Matrix(( (3, 2, 0, 0, 0),
  57. (0, 3, 2, 0, 0),
  58. (0, 0, 3, 2, 0),
  59. (0, 0, 0, 3, 2),
  60. (0, 0, 0, 0, 3) )), 243),
  61. (Matrix(( (1, 0, 1, 2, 12),
  62. (2, 0, 1, 1, 4),
  63. (2, 1, 1, -1, 3),
  64. (3, 2, -1, 1, 8),
  65. (1, 1, 1, 0, 6) )), -55),
  66. (Matrix(( (-5, 2, 3, 4, 5),
  67. ( 1, -4, 3, 4, 5),
  68. ( 1, 2, -3, 4, 5),
  69. ( 1, 2, 3, -2, 5),
  70. ( 1, 2, 3, 4, -1) )), 11664),
  71. (Matrix(( ( 2, 7, -1, 3, 2),
  72. ( 0, 0, 1, 0, 1),
  73. (-2, 0, 7, 0, 2),
  74. (-3, -2, 4, 5, 3),
  75. ( 1, 0, 0, 0, 1) )), 123),
  76. (Matrix(( (x, y, z),
  77. (1, 0, 0),
  78. (y, z, x) )), z**2 - x*y),
  79. ])
  80. def test_determinant(method, M, sol):
  81. assert M.det(method=method) == sol
  82. def test_issue_13835():
  83. a = symbols('a')
  84. M = lambda n: Matrix([[i + a*j for i in range(n)]
  85. for j in range(n)])
  86. assert M(5).det() == 0
  87. assert M(6).det() == 0
  88. assert M(7).det() == 0
  89. def test_issue_14517():
  90. M = Matrix([
  91. [ 0, 10*I, 10*I, 0],
  92. [10*I, 0, 0, 10*I],
  93. [10*I, 0, 5 + 2*I, 10*I],
  94. [ 0, 10*I, 10*I, 5 + 2*I]])
  95. ev = M.eigenvals()
  96. # test one random eigenvalue, the computation is a little slow
  97. test_ev = random.choice(list(ev.keys()))
  98. assert (M - test_ev*eye(4)).det() == 0
  99. @pytest.mark.parametrize("method", [
  100. "bareis", "det_lu", "det_LU", "Bareis", "BAREISS", "BERKOWITZ", "LU"])
  101. @pytest.mark.parametrize("M, sol", [
  102. (Matrix(( ( 3, -2, 0, 5),
  103. (-2, 1, -2, 2),
  104. ( 0, -2, 5, 0),
  105. ( 5, 0, 3, 4) )), -289),
  106. (Matrix(( (-5, 2, 3, 4, 5),
  107. ( 1, -4, 3, 4, 5),
  108. ( 1, 2, -3, 4, 5),
  109. ( 1, 2, 3, -2, 5),
  110. ( 1, 2, 3, 4, -1) )), 11664),
  111. ])
  112. def test_legacy_det(method, M, sol):
  113. # Minimal support for legacy keys for 'method' in det()
  114. # Partially copied from test_determinant()
  115. assert M.det(method=method) == sol
  116. def eye_Determinant(n):
  117. return Matrix(n, n, lambda i, j: int(i == j))
  118. def zeros_Determinant(n):
  119. return Matrix(n, n, lambda i, j: 0)
  120. def test_det():
  121. a = Matrix(2, 3, [1, 2, 3, 4, 5, 6])
  122. raises(NonSquareMatrixError, lambda: a.det())
  123. z = zeros_Determinant(2)
  124. ey = eye_Determinant(2)
  125. assert z.det() == 0
  126. assert ey.det() == 1
  127. x = Symbol('x')
  128. a = Matrix(0, 0, [])
  129. b = Matrix(1, 1, [5])
  130. c = Matrix(2, 2, [1, 2, 3, 4])
  131. d = Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 8])
  132. e = Matrix(4, 4,
  133. [x, 1, 2, 3, 4, 5, 6, 7, 2, 9, 10, 11, 12, 13, 14, 14])
  134. from sympy.abc import i, j, k, l, m, n
  135. f = Matrix(3, 3, [i, l, m, 0, j, n, 0, 0, k])
  136. g = Matrix(3, 3, [i, 0, 0, l, j, 0, m, n, k])
  137. h = Matrix(3, 3, [x**3, 0, 0, i, x**-1, 0, j, k, x**-2])
  138. # the method keyword for `det` doesn't kick in until 4x4 matrices,
  139. # so there is no need to test all methods on smaller ones
  140. assert a.det() == 1
  141. assert b.det() == 5
  142. assert c.det() == -2
  143. assert d.det() == 3
  144. assert e.det() == 4*x - 24
  145. assert e.det(method="domain-ge") == 4*x - 24
  146. assert e.det(method='bareiss') == 4*x - 24
  147. assert e.det(method='berkowitz') == 4*x - 24
  148. assert f.det() == i*j*k
  149. assert g.det() == i*j*k
  150. assert h.det() == 1
  151. raises(ValueError, lambda: e.det(iszerofunc="test"))
  152. def test_permanent():
  153. M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  154. assert M.per() == 450
  155. for i in range(1, 12):
  156. assert ones(i, i).per() == ones(i, i).T.per() == factorial(i)
  157. assert (ones(i, i)-eye(i)).per() == (ones(i, i)-eye(i)).T.per() == subfactorial(i)
  158. a1, a2, a3, a4, a5 = symbols('a_1 a_2 a_3 a_4 a_5')
  159. M = Matrix([a1, a2, a3, a4, a5])
  160. assert M.per() == M.T.per() == a1 + a2 + a3 + a4 + a5
  161. def test_adjugate():
  162. x = Symbol('x')
  163. e = Matrix(4, 4,
  164. [x, 1, 2, 3, 4, 5, 6, 7, 2, 9, 10, 11, 12, 13, 14, 14])
  165. adj = Matrix([
  166. [ 4, -8, 4, 0],
  167. [ 76, -14*x - 68, 14*x - 8, -4*x + 24],
  168. [-122, 17*x + 142, -21*x + 4, 8*x - 48],
  169. [ 48, -4*x - 72, 8*x, -4*x + 24]])
  170. assert e.adjugate() == adj
  171. assert e.adjugate(method='bareiss') == adj
  172. assert e.adjugate(method='berkowitz') == adj
  173. assert e.adjugate(method='bird') == adj
  174. assert e.adjugate(method='laplace') == adj
  175. a = Matrix(2, 3, [1, 2, 3, 4, 5, 6])
  176. raises(NonSquareMatrixError, lambda: a.adjugate())
  177. def test_util():
  178. R = Rational
  179. v1 = Matrix(1, 3, [1, 2, 3])
  180. v2 = Matrix(1, 3, [3, 4, 5])
  181. assert v1.norm() == sqrt(14)
  182. assert v1.project(v2) == Matrix(1, 3, [R(39)/25, R(52)/25, R(13)/5])
  183. assert Matrix.zeros(1, 2) == Matrix(1, 2, [0, 0])
  184. assert ones(1, 2) == Matrix(1, 2, [1, 1])
  185. assert v1.copy() == v1
  186. # cofactor
  187. assert eye(3) == eye(3).cofactor_matrix()
  188. test = Matrix([[1, 3, 2], [2, 6, 3], [2, 3, 6]])
  189. assert test.cofactor_matrix() == \
  190. Matrix([[27, -6, -6], [-12, 2, 3], [-3, 1, 0]])
  191. test = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  192. assert test.cofactor_matrix() == \
  193. Matrix([[-3, 6, -3], [6, -12, 6], [-3, 6, -3]])
  194. def test_cofactor_and_minors():
  195. x = Symbol('x')
  196. e = Matrix(4, 4,
  197. [x, 1, 2, 3, 4, 5, 6, 7, 2, 9, 10, 11, 12, 13, 14, 14])
  198. m = Matrix([
  199. [ x, 1, 3],
  200. [ 2, 9, 11],
  201. [12, 13, 14]])
  202. cm = Matrix([
  203. [ 4, 76, -122, 48],
  204. [-8, -14*x - 68, 17*x + 142, -4*x - 72],
  205. [ 4, 14*x - 8, -21*x + 4, 8*x],
  206. [ 0, -4*x + 24, 8*x - 48, -4*x + 24]])
  207. sub = Matrix([
  208. [x, 1, 2],
  209. [4, 5, 6],
  210. [2, 9, 10]])
  211. assert e.minor_submatrix(1, 2) == m
  212. assert e.minor_submatrix(-1, -1) == sub
  213. assert e.minor(1, 2) == -17*x - 142
  214. assert e.cofactor(1, 2) == 17*x + 142
  215. assert e.cofactor_matrix() == cm
  216. assert e.cofactor_matrix(method="bareiss") == cm
  217. assert e.cofactor_matrix(method="berkowitz") == cm
  218. assert e.cofactor_matrix(method="bird") == cm
  219. assert e.cofactor_matrix(method="laplace") == cm
  220. raises(ValueError, lambda: e.cofactor(4, 5))
  221. raises(ValueError, lambda: e.minor(4, 5))
  222. raises(ValueError, lambda: e.minor_submatrix(4, 5))
  223. a = Matrix(2, 3, [1, 2, 3, 4, 5, 6])
  224. assert a.minor_submatrix(0, 0) == Matrix([[5, 6]])
  225. raises(ValueError, lambda:
  226. Matrix(0, 0, []).minor_submatrix(0, 0))
  227. raises(NonSquareMatrixError, lambda: a.cofactor(0, 0))
  228. raises(NonSquareMatrixError, lambda: a.minor(0, 0))
  229. raises(NonSquareMatrixError, lambda: a.cofactor_matrix())
  230. def test_charpoly():
  231. x, y = Symbol('x'), Symbol('y')
  232. z, t = Symbol('z'), Symbol('t')
  233. from sympy.abc import a,b,c
  234. m = Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9])
  235. assert eye_Determinant(3).charpoly(x) == Poly((x - 1)**3, x)
  236. assert eye_Determinant(3).charpoly(y) == Poly((y - 1)**3, y)
  237. assert m.charpoly() == Poly(x**3 - 15*x**2 - 18*x, x)
  238. raises(NonSquareMatrixError, lambda: Matrix([[1], [2]]).charpoly())
  239. n = Matrix(4, 4, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
  240. assert n.charpoly() == Poly(x**4, x)
  241. n = Matrix(4, 4, [45, 0, 0, 0, 0, 23, 0, 0, 0, 0, 87, 0, 0, 0, 0, 12])
  242. assert n.charpoly() == Poly(x**4 - 167*x**3 + 8811*x**2 - 173457*x + 1080540, x)
  243. n = Matrix(3, 3, [x, 0, 0, a, y, 0, b, c, z])
  244. assert n.charpoly() == Poly(t**3 - (x+y+z)*t**2 + t*(x*y+y*z+x*z) - x*y*z, t)