test_xxm.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023
#
# Test basic features of DDM, SDM and DFM.
#
# These three types are supposed to be interchangeable, so for the most part
# the same tests should be used for all of them.
#
# The tests here cover the basic part of the interface that the three types
# should expose and that DomainMatrix should mostly rely on.
#
# More in-depth tests of the heavier algorithms like rref, etc., should go in
# their own test files.
#
# Any new methods added to the DDM, SDM or DFM classes should be tested here
# and added to all classes.
#
  16. from sympy.external.gmpy import GROUND_TYPES
  17. from sympy import ZZ, QQ, GF, ZZ_I, symbols
  18. from sympy.polys.matrices.exceptions import (
  19. DMBadInputError,
  20. DMDomainError,
  21. DMNonSquareMatrixError,
  22. DMNonInvertibleMatrixError,
  23. DMShapeError,
  24. )
  25. from sympy.polys.matrices.domainmatrix import DM, DomainMatrix, DDM, SDM, DFM
  26. from sympy.testing.pytest import raises, skip
  27. import pytest
  28. def test_XXM_constructors():
  29. """Test the DDM, etc constructors."""
  30. lol = [
  31. [ZZ(1), ZZ(2)],
  32. [ZZ(3), ZZ(4)],
  33. [ZZ(5), ZZ(6)],
  34. ]
  35. dod = {
  36. 0: {0: ZZ(1), 1: ZZ(2)},
  37. 1: {0: ZZ(3), 1: ZZ(4)},
  38. 2: {0: ZZ(5), 1: ZZ(6)},
  39. }
  40. lol_0x0 = []
  41. lol_0x2 = []
  42. lol_2x0 = [[], []]
  43. dod_0x0 = {}
  44. dod_0x2 = {}
  45. dod_2x0 = {}
  46. lol_bad = [
  47. [ZZ(1), ZZ(2)],
  48. [ZZ(3), ZZ(4)],
  49. [ZZ(5), ZZ(6), ZZ(7)],
  50. ]
  51. dod_bad = {
  52. 0: {0: ZZ(1), 1: ZZ(2)},
  53. 1: {0: ZZ(3), 1: ZZ(4)},
  54. 2: {0: ZZ(5), 1: ZZ(6), 2: ZZ(7)},
  55. }
  56. XDM_dense = [DDM]
  57. XDM_sparse = [SDM]
  58. if GROUND_TYPES == 'flint':
  59. XDM_dense.append(DFM)
  60. for XDM in XDM_dense:
  61. A = XDM(lol, (3, 2), ZZ)
  62. assert A.rows == 3
  63. assert A.cols == 2
  64. assert A.domain == ZZ
  65. assert A.shape == (3, 2)
  66. if XDM is not DFM:
  67. assert ZZ.of_type(A[0][0]) is True
  68. else:
  69. assert ZZ.of_type(A.rep[0, 0]) is True
  70. Adm = DomainMatrix(lol, (3, 2), ZZ)
  71. if XDM is DFM:
  72. assert Adm.rep == A
  73. assert Adm.rep.to_ddm() != A
  74. elif GROUND_TYPES == 'flint':
  75. assert Adm.rep.to_ddm() == A
  76. assert Adm.rep != A
  77. else:
  78. assert Adm.rep == A
  79. assert Adm.rep.to_ddm() == A
  80. assert XDM(lol_0x0, (0, 0), ZZ).shape == (0, 0)
  81. assert XDM(lol_0x2, (0, 2), ZZ).shape == (0, 2)
  82. assert XDM(lol_2x0, (2, 0), ZZ).shape == (2, 0)
  83. raises(DMBadInputError, lambda: XDM(lol, (2, 3), ZZ))
  84. raises(DMBadInputError, lambda: XDM(lol_bad, (3, 2), ZZ))
  85. raises(DMBadInputError, lambda: XDM(dod, (3, 2), ZZ))
  86. for XDM in XDM_sparse:
  87. A = XDM(dod, (3, 2), ZZ)
  88. assert A.rows == 3
  89. assert A.cols == 2
  90. assert A.domain == ZZ
  91. assert A.shape == (3, 2)
  92. assert ZZ.of_type(A[0][0]) is True
  93. assert DomainMatrix(dod, (3, 2), ZZ).rep == A
  94. assert XDM(dod_0x0, (0, 0), ZZ).shape == (0, 0)
  95. assert XDM(dod_0x2, (0, 2), ZZ).shape == (0, 2)
  96. assert XDM(dod_2x0, (2, 0), ZZ).shape == (2, 0)
  97. raises(DMBadInputError, lambda: XDM(dod, (2, 3), ZZ))
  98. raises(DMBadInputError, lambda: XDM(lol, (3, 2), ZZ))
  99. raises(DMBadInputError, lambda: XDM(dod_bad, (3, 2), ZZ))
  100. raises(DMBadInputError, lambda: DomainMatrix(lol, (2, 3), ZZ))
  101. raises(DMBadInputError, lambda: DomainMatrix(lol_bad, (3, 2), ZZ))
  102. raises(DMBadInputError, lambda: DomainMatrix(dod_bad, (3, 2), ZZ))
  103. def test_XXM_eq():
  104. """Test equality for DDM, SDM, DFM and DomainMatrix."""
  105. lol1 = [[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]]
  106. dod1 = {0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}}
  107. lol2 = [[ZZ(1), ZZ(2)], [ZZ(3), ZZ(5)]]
  108. dod2 = {0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(5)}}
  109. A1_ddm = DDM(lol1, (2, 2), ZZ)
  110. A1_sdm = SDM(dod1, (2, 2), ZZ)
  111. A1_dm_d = DomainMatrix(lol1, (2, 2), ZZ)
  112. A1_dm_s = DomainMatrix(dod1, (2, 2), ZZ)
  113. A2_ddm = DDM(lol2, (2, 2), ZZ)
  114. A2_sdm = SDM(dod2, (2, 2), ZZ)
  115. A2_dm_d = DomainMatrix(lol2, (2, 2), ZZ)
  116. A2_dm_s = DomainMatrix(dod2, (2, 2), ZZ)
  117. A1_all = [A1_ddm, A1_sdm, A1_dm_d, A1_dm_s]
  118. A2_all = [A2_ddm, A2_sdm, A2_dm_d, A2_dm_s]
  119. if GROUND_TYPES == 'flint':
  120. A1_dfm = DFM([[1, 2], [3, 4]], (2, 2), ZZ)
  121. A2_dfm = DFM([[1, 2], [3, 5]], (2, 2), ZZ)
  122. A1_all.append(A1_dfm)
  123. A2_all.append(A2_dfm)
  124. for n, An in enumerate(A1_all):
  125. for m, Am in enumerate(A1_all):
  126. if n == m:
  127. assert (An == Am) is True
  128. assert (An != Am) is False
  129. else:
  130. assert (An == Am) is False
  131. assert (An != Am) is True
  132. for n, An in enumerate(A2_all):
  133. for m, Am in enumerate(A2_all):
  134. if n == m:
  135. assert (An == Am) is True
  136. assert (An != Am) is False
  137. else:
  138. assert (An == Am) is False
  139. assert (An != Am) is True
  140. for n, A1 in enumerate(A1_all):
  141. for m, A2 in enumerate(A2_all):
  142. assert (A1 == A2) is False
  143. assert (A1 != A2) is True
  144. def test_to_XXM():
  145. """Test to_ddm etc. for DDM, SDM, DFM and DomainMatrix."""
  146. lol = [[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]]
  147. dod = {0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}}
  148. A_ddm = DDM(lol, (2, 2), ZZ)
  149. A_sdm = SDM(dod, (2, 2), ZZ)
  150. A_dm_d = DomainMatrix(lol, (2, 2), ZZ)
  151. A_dm_s = DomainMatrix(dod, (2, 2), ZZ)
  152. A_all = [A_ddm, A_sdm, A_dm_d, A_dm_s]
  153. if GROUND_TYPES == 'flint':
  154. A_dfm = DFM(lol, (2, 2), ZZ)
  155. A_all.append(A_dfm)
  156. for A in A_all:
  157. assert A.to_ddm() == A_ddm
  158. assert A.to_sdm() == A_sdm
  159. if GROUND_TYPES != 'flint':
  160. raises(NotImplementedError, lambda: A.to_dfm())
  161. assert A.to_dfm_or_ddm() == A_ddm
  162. # Add e.g. DDM.to_DM()?
  163. # assert A.to_DM() == A_dm
  164. if GROUND_TYPES == 'flint':
  165. for A in A_all:
  166. assert A.to_dfm() == A_dfm
  167. for K in [ZZ, QQ, GF(5), ZZ_I]:
  168. if isinstance(A, DFM) and not DFM._supports_domain(K):
  169. raises(NotImplementedError, lambda: A.convert_to(K))
  170. else:
  171. A_K = A.convert_to(K)
  172. if DFM._supports_domain(K):
  173. A_dfm_K = A_dfm.convert_to(K)
  174. assert A_K.to_dfm() == A_dfm_K
  175. assert A_K.to_dfm_or_ddm() == A_dfm_K
  176. else:
  177. raises(NotImplementedError, lambda: A_K.to_dfm())
  178. assert A_K.to_dfm_or_ddm() == A_ddm.convert_to(K)
  179. def test_DFM_domains():
  180. """Test which domains are supported by DFM."""
  181. x, y = symbols('x, y')
  182. if GROUND_TYPES in ('python', 'gmpy'):
  183. supported = []
  184. flint_funcs = {}
  185. not_supported = [ZZ, QQ, GF(5), QQ[x], QQ[x,y]]
  186. elif GROUND_TYPES == 'flint':
  187. import flint
  188. supported = [ZZ, QQ]
  189. flint_funcs = {
  190. ZZ: flint.fmpz_mat,
  191. QQ: flint.fmpq_mat,
  192. GF(5): None,
  193. }
  194. not_supported = [
  195. # Other domains could be supported but not implemented as matrices
  196. # in python-flint:
  197. QQ[x],
  198. QQ[x,y],
  199. QQ.frac_field(x,y),
  200. # Others would potentially never be supported by python-flint:
  201. ZZ_I,
  202. ]
  203. else:
  204. assert False, "Unknown GROUND_TYPES: %s" % GROUND_TYPES
  205. for domain in supported:
  206. assert DFM._supports_domain(domain) is True
  207. if flint_funcs[domain] is not None:
  208. assert DFM._get_flint_func(domain) == flint_funcs[domain]
  209. for domain in not_supported:
  210. assert DFM._supports_domain(domain) is False
  211. raises(NotImplementedError, lambda: DFM._get_flint_func(domain))
  212. def _DM(lol, typ, K):
  213. """Make a DM of type typ over K from lol."""
  214. A = DM(lol, K)
  215. if typ == 'DDM':
  216. return A.to_ddm()
  217. elif typ == 'SDM':
  218. return A.to_sdm()
  219. elif typ == 'DFM':
  220. if GROUND_TYPES != 'flint':
  221. skip("DFM not supported in this ground type")
  222. return A.to_dfm()
  223. else:
  224. assert False, "Unknown type %s" % typ
  225. def _DMZ(lol, typ):
  226. """Make a DM of type typ over ZZ from lol."""
  227. return _DM(lol, typ, ZZ)
  228. def _DMQ(lol, typ):
  229. """Make a DM of type typ over QQ from lol."""
  230. return _DM(lol, typ, QQ)
  231. def DM_ddm(lol, K):
  232. """Make a DDM over K from lol."""
  233. return _DM(lol, 'DDM', K)
  234. def DM_sdm(lol, K):
  235. """Make a SDM over K from lol."""
  236. return _DM(lol, 'SDM', K)
  237. def DM_dfm(lol, K):
  238. """Make a DFM over K from lol."""
  239. return _DM(lol, 'DFM', K)
  240. def DMZ_ddm(lol):
  241. """Make a DDM from lol."""
  242. return _DMZ(lol, 'DDM')
  243. def DMZ_sdm(lol):
  244. """Make a SDM from lol."""
  245. return _DMZ(lol, 'SDM')
  246. def DMZ_dfm(lol):
  247. """Make a DFM from lol."""
  248. return _DMZ(lol, 'DFM')
  249. def DMQ_ddm(lol):
  250. """Make a DDM from lol."""
  251. return _DMQ(lol, 'DDM')
  252. def DMQ_sdm(lol):
  253. """Make a SDM from lol."""
  254. return _DMQ(lol, 'SDM')
  255. def DMQ_dfm(lol):
  256. """Make a DFM from lol."""
  257. return _DMQ(lol, 'DFM')
  258. DM_all = [DM_ddm, DM_sdm, DM_dfm]
  259. DMZ_all = [DMZ_ddm, DMZ_sdm, DMZ_dfm]
  260. DMQ_all = [DMQ_ddm, DMQ_sdm, DMQ_dfm]
  261. @pytest.mark.parametrize('DM', DMZ_all)
  262. def test_XDM_getitem(DM):
  263. """Test getitem for DDM, etc."""
  264. lol = [[0, 1], [2, 0]]
  265. A = DM(lol)
  266. m, n = A.shape
  267. indices = [-3, -2, -1, 0, 1, 2]
  268. for i in indices:
  269. for j in indices:
  270. if -2 <= i < m and -2 <= j < n:
  271. assert A.getitem(i, j) == ZZ(lol[i][j])
  272. else:
  273. raises(IndexError, lambda: A.getitem(i, j))
  274. @pytest.mark.parametrize('DM', DMZ_all)
  275. def test_XDM_setitem(DM):
  276. """Test setitem for DDM, etc."""
  277. A = DM([[0, 1, 2], [3, 4, 5]])
  278. A.setitem(0, 0, ZZ(6))
  279. assert A == DM([[6, 1, 2], [3, 4, 5]])
  280. A.setitem(0, 1, ZZ(7))
  281. assert A == DM([[6, 7, 2], [3, 4, 5]])
  282. A.setitem(0, 2, ZZ(8))
  283. assert A == DM([[6, 7, 8], [3, 4, 5]])
  284. A.setitem(0, -1, ZZ(9))
  285. assert A == DM([[6, 7, 9], [3, 4, 5]])
  286. A.setitem(0, -2, ZZ(10))
  287. assert A == DM([[6, 10, 9], [3, 4, 5]])
  288. A.setitem(0, -3, ZZ(11))
  289. assert A == DM([[11, 10, 9], [3, 4, 5]])
  290. raises(IndexError, lambda: A.setitem(0, 3, ZZ(12)))
  291. raises(IndexError, lambda: A.setitem(0, -4, ZZ(13)))
  292. A.setitem(1, 0, ZZ(14))
  293. assert A == DM([[11, 10, 9], [14, 4, 5]])
  294. A.setitem(1, 1, ZZ(15))
  295. assert A == DM([[11, 10, 9], [14, 15, 5]])
  296. A.setitem(-1, 1, ZZ(16))
  297. assert A == DM([[11, 10, 9], [14, 16, 5]])
  298. A.setitem(-2, 1, ZZ(17))
  299. assert A == DM([[11, 17, 9], [14, 16, 5]])
  300. raises(IndexError, lambda: A.setitem(2, 0, ZZ(18)))
  301. raises(IndexError, lambda: A.setitem(-3, 0, ZZ(19)))
  302. A.setitem(1, 2, ZZ(0))
  303. assert A == DM([[11, 17, 9], [14, 16, 0]])
  304. A.setitem(1, -2, ZZ(0))
  305. assert A == DM([[11, 17, 9], [14, 0, 0]])
  306. A.setitem(1, -3, ZZ(0))
  307. assert A == DM([[11, 17, 9], [0, 0, 0]])
  308. A.setitem(0, 0, ZZ(0))
  309. assert A == DM([[0, 17, 9], [0, 0, 0]])
  310. A.setitem(0, -1, ZZ(0))
  311. assert A == DM([[0, 17, 0], [0, 0, 0]])
  312. A.setitem(0, 0, ZZ(0))
  313. assert A == DM([[0, 17, 0], [0, 0, 0]])
  314. A.setitem(0, -2, ZZ(0))
  315. assert A == DM([[0, 0, 0], [0, 0, 0]])
  316. A.setitem(0, -3, ZZ(1))
  317. assert A == DM([[1, 0, 0], [0, 0, 0]])
  318. class _Sliced:
  319. def __getitem__(self, item):
  320. return item
  321. _slice = _Sliced()
  322. @pytest.mark.parametrize('DM', DMZ_all)
  323. def test_XXM_extract_slice(DM):
  324. A = DM([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  325. assert A.extract_slice(*_slice[:,:]) == A
  326. assert A.extract_slice(*_slice[1:,:]) == DM([[4, 5, 6], [7, 8, 9]])
  327. assert A.extract_slice(*_slice[1:,1:]) == DM([[5, 6], [8, 9]])
  328. assert A.extract_slice(*_slice[1:,:-1]) == DM([[4, 5], [7, 8]])
  329. assert A.extract_slice(*_slice[1:,:-1:2]) == DM([[4], [7]])
  330. assert A.extract_slice(*_slice[:,::2]) == DM([[1, 3], [4, 6], [7, 9]])
  331. assert A.extract_slice(*_slice[::2,:]) == DM([[1, 2, 3], [7, 8, 9]])
  332. assert A.extract_slice(*_slice[::2,::2]) == DM([[1, 3], [7, 9]])
  333. assert A.extract_slice(*_slice[::2,::-2]) == DM([[3, 1], [9, 7]])
  334. assert A.extract_slice(*_slice[::-2,::2]) == DM([[7, 9], [1, 3]])
  335. assert A.extract_slice(*_slice[::-2,::-2]) == DM([[9, 7], [3, 1]])
  336. assert A.extract_slice(*_slice[:,::-1]) == DM([[3, 2, 1], [6, 5, 4], [9, 8, 7]])
  337. assert A.extract_slice(*_slice[::-1,:]) == DM([[7, 8, 9], [4, 5, 6], [1, 2, 3]])
  338. @pytest.mark.parametrize('DM', DMZ_all)
  339. def test_XXM_extract(DM):
  340. A = DM([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  341. assert A.extract([0, 1, 2], [0, 1, 2]) == A
  342. assert A.extract([1, 2], [1, 2]) == DM([[5, 6], [8, 9]])
  343. assert A.extract([1, 2], [0, 1]) == DM([[4, 5], [7, 8]])
  344. assert A.extract([1, 2], [0, 2]) == DM([[4, 6], [7, 9]])
  345. assert A.extract([1, 2], [0]) == DM([[4], [7]])
  346. assert A.extract([1, 2], []) == DM([[1]]).zeros((2, 0), ZZ)
  347. assert A.extract([], [0, 1, 2]) == DM([[1]]).zeros((0, 3), ZZ)
  348. raises(IndexError, lambda: A.extract([1, 2], [0, 3]))
  349. raises(IndexError, lambda: A.extract([1, 2], [0, -4]))
  350. raises(IndexError, lambda: A.extract([3, 1], [0, 1]))
  351. raises(IndexError, lambda: A.extract([-4, 2], [3, 1]))
  352. B = DM([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
  353. assert B.extract([1, 2], [1, 2]) == DM([[0, 0], [0, 0]])
  354. def test_XXM_str():
  355. A = DomainMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]], (3, 3), ZZ)
  356. assert str(A) == \
  357. 'DomainMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]], (3, 3), ZZ)'
  358. assert str(A.to_ddm()) == \
  359. '[[1, 2, 3], [4, 5, 6], [7, 8, 9]]'
  360. assert str(A.to_sdm()) == \
  361. '{0: {0: 1, 1: 2, 2: 3}, 1: {0: 4, 1: 5, 2: 6}, 2: {0: 7, 1: 8, 2: 9}}'
  362. assert repr(A) == \
  363. 'DomainMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]], (3, 3), ZZ)'
  364. assert repr(A.to_ddm()) == \
  365. 'DDM([[1, 2, 3], [4, 5, 6], [7, 8, 9]], (3, 3), ZZ)'
  366. assert repr(A.to_sdm()) == \
  367. 'SDM({0: {0: 1, 1: 2, 2: 3}, 1: {0: 4, 1: 5, 2: 6}, 2: {0: 7, 1: 8, 2: 9}}, (3, 3), ZZ)'
  368. B = DomainMatrix({0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3)}}, (2, 2), ZZ)
  369. assert str(B) == \
  370. 'DomainMatrix({0: {0: 1, 1: 2}, 1: {0: 3}}, (2, 2), ZZ)'
  371. assert str(B.to_ddm()) == \
  372. '[[1, 2], [3, 0]]'
  373. assert str(B.to_sdm()) == \
  374. '{0: {0: 1, 1: 2}, 1: {0: 3}}'
  375. assert repr(B) == \
  376. 'DomainMatrix({0: {0: 1, 1: 2}, 1: {0: 3}}, (2, 2), ZZ)'
  377. if GROUND_TYPES != 'gmpy':
  378. assert repr(B.to_ddm()) == \
  379. 'DDM([[1, 2], [3, 0]], (2, 2), ZZ)'
  380. assert repr(B.to_sdm()) == \
  381. 'SDM({0: {0: 1, 1: 2}, 1: {0: 3}}, (2, 2), ZZ)'
  382. else:
  383. assert repr(B.to_ddm()) == \
  384. 'DDM([[mpz(1), mpz(2)], [mpz(3), mpz(0)]], (2, 2), ZZ)'
  385. assert repr(B.to_sdm()) == \
  386. 'SDM({0: {0: mpz(1), 1: mpz(2)}, 1: {0: mpz(3)}}, (2, 2), ZZ)'
  387. if GROUND_TYPES == 'flint':
  388. assert str(A.to_dfm()) == \
  389. '[[1, 2, 3], [4, 5, 6], [7, 8, 9]]'
  390. assert str(B.to_dfm()) == \
  391. '[[1, 2], [3, 0]]'
  392. assert repr(A.to_dfm()) == \
  393. 'DFM([[1, 2, 3], [4, 5, 6], [7, 8, 9]], (3, 3), ZZ)'
  394. assert repr(B.to_dfm()) == \
  395. 'DFM([[1, 2], [3, 0]], (2, 2), ZZ)'
  396. @pytest.mark.parametrize('DM', DMZ_all)
  397. def test_XXM_from_list(DM):
  398. T = type(DM([[0]]))
  399. lol = [[1, 2, 4], [4, 5, 6]]
  400. lol_ZZ = [[ZZ(1), ZZ(2), ZZ(4)], [ZZ(4), ZZ(5), ZZ(6)]]
  401. lol_ZZ_bad = [[ZZ(1), ZZ(2), ZZ(4)], [ZZ(4), ZZ(5), ZZ(6), ZZ(7)]]
  402. assert T.from_list(lol_ZZ, (2, 3), ZZ) == DM(lol)
  403. raises(DMBadInputError, lambda: T.from_list(lol_ZZ_bad, (3, 2), ZZ))
  404. @pytest.mark.parametrize('DM', DMZ_all)
  405. def test_XXM_to_list(DM):
  406. lol = [[1, 2, 4], [4, 5, 6]]
  407. assert DM(lol).to_list() == [[ZZ(1), ZZ(2), ZZ(4)], [ZZ(4), ZZ(5), ZZ(6)]]
  408. @pytest.mark.parametrize('DM', DMZ_all)
  409. def test_XXM_to_list_flat(DM):
  410. lol = [[1, 2, 4], [4, 5, 6]]
  411. assert DM(lol).to_list_flat() == [ZZ(1), ZZ(2), ZZ(4), ZZ(4), ZZ(5), ZZ(6)]
  412. @pytest.mark.parametrize('DM', DMZ_all)
  413. def test_XXM_from_list_flat(DM):
  414. T = type(DM([[0]]))
  415. flat = [ZZ(1), ZZ(2), ZZ(4), ZZ(4), ZZ(5), ZZ(6)]
  416. assert T.from_list_flat(flat, (2, 3), ZZ) == DM([[1, 2, 4], [4, 5, 6]])
  417. raises(DMBadInputError, lambda: T.from_list_flat(flat, (3, 3), ZZ))
  418. @pytest.mark.parametrize('DM', DMZ_all)
  419. def test_XXM_to_flat_nz(DM):
  420. M = DM([[1, 2, 0], [0, 0, 0], [0, 0, 3]])
  421. elements = [ZZ(1), ZZ(2), ZZ(3)]
  422. indices = ((0, 0), (0, 1), (2, 2))
  423. assert M.to_flat_nz() == (elements, (indices, M.shape))
  424. @pytest.mark.parametrize('DM', DMZ_all)
  425. def test_XXM_from_flat_nz(DM):
  426. T = type(DM([[0]]))
  427. elements = [ZZ(1), ZZ(2), ZZ(3)]
  428. indices = ((0, 0), (0, 1), (2, 2))
  429. data = (indices, (3, 3))
  430. result = DM([[1, 2, 0], [0, 0, 0], [0, 0, 3]])
  431. assert T.from_flat_nz(elements, data, ZZ) == result
  432. raises(DMBadInputError, lambda: T.from_flat_nz(elements, (indices, (2, 3)), ZZ))
  433. @pytest.mark.parametrize('DM', DMZ_all)
  434. def test_XXM_to_dod(DM):
  435. dod = {0: {0: ZZ(1), 2: ZZ(4)}, 1: {0: ZZ(4), 1: ZZ(5), 2: ZZ(6)}}
  436. assert DM([[1, 0, 4], [4, 5, 6]]).to_dod() == dod
  437. @pytest.mark.parametrize('DM', DMZ_all)
  438. def test_XXM_from_dod(DM):
  439. T = type(DM([[0]]))
  440. dod = {0: {0: ZZ(1), 2: ZZ(4)}, 1: {0: ZZ(4), 1: ZZ(5), 2: ZZ(6)}}
  441. assert T.from_dod(dod, (2, 3), ZZ) == DM([[1, 0, 4], [4, 5, 6]])
  442. @pytest.mark.parametrize('DM', DMZ_all)
  443. def test_XXM_to_dok(DM):
  444. dod = {(0, 0): ZZ(1), (0, 2): ZZ(4),
  445. (1, 0): ZZ(4), (1, 1): ZZ(5), (1, 2): ZZ(6)}
  446. assert DM([[1, 0, 4], [4, 5, 6]]).to_dok() == dod
  447. @pytest.mark.parametrize('DM', DMZ_all)
  448. def test_XXM_from_dok(DM):
  449. T = type(DM([[0]]))
  450. dod = {(0, 0): ZZ(1), (0, 2): ZZ(4),
  451. (1, 0): ZZ(4), (1, 1): ZZ(5), (1, 2): ZZ(6)}
  452. assert T.from_dok(dod, (2, 3), ZZ) == DM([[1, 0, 4], [4, 5, 6]])
  453. @pytest.mark.parametrize('DM', DMZ_all)
  454. def test_XXM_iter_values(DM):
  455. values = [ZZ(1), ZZ(4), ZZ(4), ZZ(5), ZZ(6)]
  456. assert sorted(DM([[1, 0, 4], [4, 5, 6]]).iter_values()) == values
  457. @pytest.mark.parametrize('DM', DMZ_all)
  458. def test_XXM_iter_items(DM):
  459. items = [((0, 0), ZZ(1)), ((0, 2), ZZ(4)),
  460. ((1, 0), ZZ(4)), ((1, 1), ZZ(5)), ((1, 2), ZZ(6))]
  461. assert sorted(DM([[1, 0, 4], [4, 5, 6]]).iter_items()) == items
  462. @pytest.mark.parametrize('DM', DMZ_all)
  463. def test_XXM_from_ddm(DM):
  464. T = type(DM([[0]]))
  465. ddm = DDM([[1, 2, 4], [4, 5, 6]], (2, 3), ZZ)
  466. assert T.from_ddm(ddm) == DM([[1, 2, 4], [4, 5, 6]])
  467. @pytest.mark.parametrize('DM', DMZ_all)
  468. def test_XXM_zeros(DM):
  469. T = type(DM([[0]]))
  470. assert T.zeros((2, 3), ZZ) == DM([[0, 0, 0], [0, 0, 0]])
  471. @pytest.mark.parametrize('DM', DMZ_all)
  472. def test_XXM_ones(DM):
  473. T = type(DM([[0]]))
  474. assert T.ones((2, 3), ZZ) == DM([[1, 1, 1], [1, 1, 1]])
  475. @pytest.mark.parametrize('DM', DMZ_all)
  476. def test_XXM_eye(DM):
  477. T = type(DM([[0]]))
  478. assert T.eye(3, ZZ) == DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
  479. assert T.eye((3, 2), ZZ) == DM([[1, 0], [0, 1], [0, 0]])
  480. @pytest.mark.parametrize('DM', DMZ_all)
  481. def test_XXM_diag(DM):
  482. T = type(DM([[0]]))
  483. assert T.diag([1, 2, 3], ZZ) == DM([[1, 0, 0], [0, 2, 0], [0, 0, 3]])
  484. @pytest.mark.parametrize('DM', DMZ_all)
  485. def test_XXM_transpose(DM):
  486. A = DM([[1, 2, 3], [4, 5, 6]])
  487. assert A.transpose() == DM([[1, 4], [2, 5], [3, 6]])
  488. @pytest.mark.parametrize('DM', DMZ_all)
  489. def test_XXM_add(DM):
  490. A = DM([[1, 2, 3], [4, 5, 6]])
  491. B = DM([[1, 2, 3], [4, 5, 6]])
  492. C = DM([[2, 4, 6], [8, 10, 12]])
  493. assert A.add(B) == C
  494. @pytest.mark.parametrize('DM', DMZ_all)
  495. def test_XXM_sub(DM):
  496. A = DM([[1, 2, 3], [4, 5, 6]])
  497. B = DM([[1, 2, 3], [4, 5, 6]])
  498. C = DM([[0, 0, 0], [0, 0, 0]])
  499. assert A.sub(B) == C
  500. @pytest.mark.parametrize('DM', DMZ_all)
  501. def test_XXM_mul(DM):
  502. A = DM([[1, 2, 3], [4, 5, 6]])
  503. b = ZZ(2)
  504. assert A.mul(b) == DM([[2, 4, 6], [8, 10, 12]])
  505. assert A.rmul(b) == DM([[2, 4, 6], [8, 10, 12]])
  506. @pytest.mark.parametrize('DM', DMZ_all)
  507. def test_XXM_matmul(DM):
  508. A = DM([[1, 2, 3], [4, 5, 6]])
  509. B = DM([[1, 2], [3, 4], [5, 6]])
  510. C = DM([[22, 28], [49, 64]])
  511. assert A.matmul(B) == C
  512. @pytest.mark.parametrize('DM', DMZ_all)
  513. def test_XXM_mul_elementwise(DM):
  514. A = DM([[1, 2, 3], [4, 5, 6]])
  515. B = DM([[1, 2, 3], [4, 5, 6]])
  516. C = DM([[1, 4, 9], [16, 25, 36]])
  517. assert A.mul_elementwise(B) == C
  518. @pytest.mark.parametrize('DM', DMZ_all)
  519. def test_XXM_neg(DM):
  520. A = DM([[1, 2, 3], [4, 5, 6]])
  521. C = DM([[-1, -2, -3], [-4, -5, -6]])
  522. assert A.neg() == C
  523. @pytest.mark.parametrize('DM', DM_all)
  524. def test_XXM_convert_to(DM):
  525. A = DM([[1, 2, 3], [4, 5, 6]], ZZ)
  526. B = DM([[1, 2, 3], [4, 5, 6]], QQ)
  527. assert A.convert_to(QQ) == B
  528. assert B.convert_to(ZZ) == A
  529. @pytest.mark.parametrize('DM', DMZ_all)
  530. def test_XXM_scc(DM):
  531. A = DM([
  532. [0, 1, 0, 0, 0, 0],
  533. [1, 0, 0, 0, 0, 0],
  534. [0, 0, 1, 0, 0, 0],
  535. [0, 0, 0, 1, 0, 1],
  536. [0, 0, 0, 0, 1, 0],
  537. [0, 0, 0, 1, 0, 1]])
  538. assert A.scc() == [[0, 1], [2], [3, 5], [4]]
  539. @pytest.mark.parametrize('DM', DMZ_all)
  540. def test_XXM_hstack(DM):
  541. A = DM([[1, 2, 3], [4, 5, 6]])
  542. B = DM([[7, 8], [9, 10]])
  543. C = DM([[1, 2, 3, 7, 8], [4, 5, 6, 9, 10]])
  544. ABC = DM([[1, 2, 3, 7, 8, 1, 2, 3, 7, 8],
  545. [4, 5, 6, 9, 10, 4, 5, 6, 9, 10]])
  546. assert A.hstack(B) == C
  547. assert A.hstack(B, C) == ABC
  548. @pytest.mark.parametrize('DM', DMZ_all)
  549. def test_XXM_vstack(DM):
  550. A = DM([[1, 2, 3], [4, 5, 6]])
  551. B = DM([[7, 8, 9]])
  552. C = DM([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  553. ABC = DM([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
  554. assert A.vstack(B) == C
  555. assert A.vstack(B, C) == ABC
  556. @pytest.mark.parametrize('DM', DMZ_all)
  557. def test_XXM_applyfunc(DM):
  558. A = DM([[1, 2, 3], [4, 5, 6]])
  559. B = DM([[2, 4, 6], [8, 10, 12]])
  560. assert A.applyfunc(lambda x: 2*x, ZZ) == B
  561. @pytest.mark.parametrize('DM', DMZ_all)
  562. def test_XXM_is_upper(DM):
  563. assert DM([[1, 2, 3], [0, 5, 6]]).is_upper() is True
  564. assert DM([[1, 2, 3], [4, 5, 6]]).is_upper() is False
  565. assert DM([]).is_upper() is True
  566. assert DM([[], []]).is_upper() is True
  567. @pytest.mark.parametrize('DM', DMZ_all)
  568. def test_XXM_is_lower(DM):
  569. assert DM([[1, 0, 0], [4, 5, 0]]).is_lower() is True
  570. assert DM([[1, 2, 3], [4, 5, 6]]).is_lower() is False
  571. @pytest.mark.parametrize('DM', DMZ_all)
  572. def test_XXM_is_diagonal(DM):
  573. assert DM([[1, 0, 0], [0, 5, 0]]).is_diagonal() is True
  574. assert DM([[1, 2, 3], [4, 5, 6]]).is_diagonal() is False
  575. @pytest.mark.parametrize('DM', DMZ_all)
  576. def test_XXM_diagonal(DM):
  577. assert DM([[1, 0, 0], [0, 5, 0]]).diagonal() == [1, 5]
  578. @pytest.mark.parametrize('DM', DMZ_all)
  579. def test_XXM_is_zero_matrix(DM):
  580. assert DM([[0, 0, 0], [0, 0, 0]]).is_zero_matrix() is True
  581. assert DM([[1, 0, 0], [0, 0, 0]]).is_zero_matrix() is False
  582. @pytest.mark.parametrize('DM', DMZ_all)
  583. def test_XXM_det_ZZ(DM):
  584. assert DM([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).det() == 0
  585. assert DM([[1, 2, 3], [4, 5, 6], [7, 8, 10]]).det() == -3
  586. @pytest.mark.parametrize('DM', DMQ_all)
  587. def test_XXM_det_QQ(DM):
  588. dM1 = DM([[(1,2), (2,3)], [(3,4), (4,5)]])
  589. assert dM1.det() == QQ(-1,10)
  590. @pytest.mark.parametrize('DM', DMQ_all)
  591. def test_XXM_inv_QQ(DM):
  592. dM1 = DM([[(1,2), (2,3)], [(3,4), (4,5)]])
  593. dM2 = DM([[(-8,1), (20,3)], [(15,2), (-5,1)]])
  594. assert dM1.inv() == dM2
  595. assert dM1.matmul(dM2) == DM([[1, 0], [0, 1]])
  596. dM3 = DM([[(1,2), (2,3)], [(1,4), (1,3)]])
  597. raises(DMNonInvertibleMatrixError, lambda: dM3.inv())
  598. dM4 = DM([[(1,2), (2,3), (3,4)], [(1,4), (1,3), (1,2)]])
  599. raises(DMNonSquareMatrixError, lambda: dM4.inv())
  600. @pytest.mark.parametrize('DM', DMZ_all)
  601. def test_XXM_inv_ZZ(DM):
  602. dM1 = DM([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
  603. # XXX: Maybe this should return a DM over QQ instead?
  604. # XXX: Handle unimodular matrices?
  605. raises(DMDomainError, lambda: dM1.inv())
  606. @pytest.mark.parametrize('DM', DMZ_all)
  607. def test_XXM_charpoly_ZZ(DM):
  608. dM1 = DM([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
  609. assert dM1.charpoly() == [1, -16, -12, 3]
  610. @pytest.mark.parametrize('DM', DMQ_all)
  611. def test_XXM_charpoly_QQ(DM):
  612. dM1 = DM([[(1,2), (2,3)], [(3,4), (4,5)]])
  613. assert dM1.charpoly() == [QQ(1,1), QQ(-13,10), QQ(-1,10)]
  614. @pytest.mark.parametrize('DM', DMZ_all)
  615. def test_XXM_lu_solve_ZZ(DM):
  616. dM1 = DM([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
  617. dM2 = DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
  618. raises(DMDomainError, lambda: dM1.lu_solve(dM2))
  619. @pytest.mark.parametrize('DM', DMQ_all)
  620. def test_XXM_lu_solve_QQ(DM):
  621. dM1 = DM([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
  622. dM2 = DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
  623. dM3 = DM([[(-2,3),(-4,3),(1,1)],[(-2,3),(11,3),(-2,1)],[(1,1),(-2,1),(1,1)]])
  624. assert dM1.lu_solve(dM2) == dM3 == dM1.inv()
  625. dM4 = DM([[1, 2, 3], [4, 5, 6]])
  626. dM5 = DM([[1, 0], [0, 1], [0, 0]])
  627. raises(DMShapeError, lambda: dM4.lu_solve(dM5))
  628. @pytest.mark.parametrize('DM', DMQ_all)
  629. def test_XXM_nullspace_QQ(DM):
  630. dM1 = DM([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  631. # XXX: Change the signature to just return the nullspace. Possibly
  632. # returning the rank or nullity makes sense but the list of nonpivots is
  633. # not useful.
  634. assert dM1.nullspace() == (DM([[1, -2, 1]]), [2])
  635. @pytest.mark.parametrize('DM', DMZ_all)
  636. def test_XXM_lll(DM):
  637. M = DM([[1, 2, 3], [4, 5, 20]])
  638. M_lll = DM([[1, 2, 3], [-1, -5, 5]])
  639. T = DM([[1, 0], [-5, 1]])
  640. assert M.lll() == M_lll
  641. assert M.lll_transform() == (M_lll, T)
  642. assert T.matmul(M) == M_lll
  643. @pytest.mark.parametrize('DM', DMQ_all)
  644. def test_XXM_qr_mixed_signs(DM):
  645. lol = [[QQ(1), QQ(-2)], [QQ(-3), QQ(4)]]
  646. A = DM(lol)
  647. Q, R = A.qr()
  648. assert Q.matmul(R) == A
  649. assert (Q.transpose().matmul(Q)).is_diagonal
  650. assert R.is_upper
  651. @pytest.mark.parametrize('DM', DMQ_all)
  652. def test_XXM_qr_large_matrix(DM):
  653. lol = [[QQ(i + j) for j in range(10)] for i in range(10)]
  654. A = DM(lol)
  655. Q, R = A.qr()
  656. assert Q.matmul(R) == A
  657. assert (Q.transpose().matmul(Q)).is_diagonal
  658. assert R.is_upper
  659. @pytest.mark.parametrize('DM', DMQ_all)
  660. def test_XXM_qr_identity_matrix(DM):
  661. T = type(DM([[0]]))
  662. A = T.eye(3, QQ)
  663. Q, R = A.qr()
  664. assert Q == A
  665. assert R == A
  666. assert (Q.transpose().matmul(Q)).is_diagonal
  667. assert R.is_upper
  668. assert Q.shape == (3, 3)
  669. assert R.shape == (3, 3)
  670. @pytest.mark.parametrize('DM', DMQ_all)
  671. def test_XXM_qr_square_matrix(DM):
  672. lol = [[QQ(3), QQ(1)], [QQ(4), QQ(3)]]
  673. A = DM(lol)
  674. Q, R = A.qr()
  675. assert Q.matmul(R) == A
  676. assert (Q.transpose().matmul(Q)).is_diagonal
  677. assert R.is_upper
  678. @pytest.mark.parametrize('DM', DMQ_all)
  679. def test_XXM_qr_matrix_with_zero_columns(DM):
  680. lol = [[QQ(3), QQ(0)], [QQ(4), QQ(0)]]
  681. A = DM(lol)
  682. Q, R = A.qr()
  683. assert Q.matmul(R) == A
  684. assert (Q.transpose().matmul(Q)).is_diagonal
  685. assert R.is_upper
  686. @pytest.mark.parametrize('DM', DMQ_all)
  687. def test_XXM_qr_linearly_dependent_columns(DM):
  688. lol = [[QQ(1), QQ(2)], [QQ(2), QQ(4)]]
  689. A = DM(lol)
  690. Q, R = A.qr()
  691. assert Q.matmul(R) == A
  692. assert (Q.transpose().matmul(Q)).is_diagonal
  693. assert R.is_upper
  694. @pytest.mark.parametrize('DM', DMZ_all)
  695. def test_XXM_qr_non_field(DM):
  696. lol = [[ZZ(3), ZZ(1)], [ZZ(4), ZZ(3)]]
  697. A = DM(lol)
  698. with pytest.raises(DMDomainError):
  699. A.qr()
  700. @pytest.mark.parametrize('DM', DMQ_all)
  701. def test_XXM_qr_field(DM):
  702. lol = [[QQ(3), QQ(1)], [QQ(4), QQ(3)]]
  703. A = DM(lol)
  704. Q, R = A.qr()
  705. assert Q.matmul(R) == A
  706. assert (Q.transpose().matmul(Q)).is_diagonal
  707. assert R.is_upper
  708. @pytest.mark.parametrize('DM', DMQ_all)
  709. def test_XXM_qr_tall_matrix(DM):
  710. lol = [[QQ(1), QQ(2)], [QQ(3), QQ(4)], [QQ(5), QQ(6)]]
  711. A = DM(lol)
  712. Q, R = A.qr()
  713. assert Q.matmul(R) == A
  714. assert (Q.transpose().matmul(Q)).is_diagonal
  715. assert R.is_upper
  716. @pytest.mark.parametrize('DM', DMQ_all)
  717. def test_XXM_qr_wide_matrix(DM):
  718. lol = [[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(5), QQ(6)]]
  719. A = DM(lol)
  720. Q, R = A.qr()
  721. assert Q.matmul(R) == A
  722. assert (Q.transpose().matmul(Q)).is_diagonal
  723. assert R.is_upper
  724. @pytest.mark.parametrize('DM', DMQ_all)
  725. def test_XXM_qr_empty_matrix_0x0(DM):
  726. T = type(DM([[0]]))
  727. A = T.zeros((0, 0), QQ)
  728. Q, R = A.qr()
  729. assert Q.matmul(R).shape == A.shape
  730. assert (Q.transpose().matmul(Q)).is_diagonal
  731. assert R.is_upper
  732. assert Q.shape == (0, 0)
  733. assert R.shape == (0, 0)
  734. @pytest.mark.parametrize('DM', DMQ_all)
  735. def test_XXM_qr_empty_matrix_2x0(DM):
  736. T = type(DM([[0]]))
  737. A = T.zeros((2, 0), QQ)
  738. Q, R = A.qr()
  739. assert Q.matmul(R).shape == A.shape
  740. assert (Q.transpose().matmul(Q)).is_diagonal
  741. assert R.is_upper
  742. assert Q.shape == (2, 0)
  743. assert R.shape == (0, 0)
  744. @pytest.mark.parametrize('DM', DMQ_all)
  745. def test_XXM_qr_empty_matrix_0x2(DM):
  746. T = type(DM([[0]]))
  747. A = T.zeros((0, 2), QQ)
  748. Q, R = A.qr()
  749. assert Q.matmul(R).shape == A.shape
  750. assert (Q.transpose().matmul(Q)).is_diagonal
  751. assert R.is_upper
  752. assert Q.shape == (0, 0)
  753. assert R.shape == (0, 2)
  754. @pytest.mark.parametrize('DM', DMZ_all)
  755. def test_XXM_fflu(DM):
  756. A = DM([[1, 2], [3, 4]])
  757. P, L, D, U = A.fflu()
  758. A_field = A.convert_to(QQ)
  759. P_field = P.convert_to(QQ)
  760. L_field = L.convert_to(QQ)
  761. D_field = D.convert_to(QQ)
  762. U_field = U.convert_to(QQ)
  763. assert P.shape == A.shape
  764. assert L.shape == A.shape
  765. assert D.shape == A.shape
  766. assert U.shape == A.shape
  767. assert P == DM([[1, 0], [0, 1]])
  768. assert L == DM([[1, 0], [3, -2]])
  769. assert D == DM([[1, 0], [0, -2]])
  770. assert U == DM([[1, 2], [0, -2]])
  771. assert L_field.matmul(D_field.inv()).matmul(U_field) == P_field.matmul(A_field)