test_fflu.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. from sympy.polys.matrices import DomainMatrix, DM
  2. from sympy.polys.domains import ZZ, QQ
  3. from sympy import Matrix
  4. import pytest
# Table of inputs for the parametrized fflu() tests in this module.
#
# Each entry is a 6-tuple that the tests unpack as
#     (name, A, P_ans, L_ans, D_ans, U_ans)
# where ``name`` is a label for the case, ``A`` is the input matrix and the
# remaining four items are the listed P, L, D and U factors for that input.
#
# NOTE(review): the checker used by the tests in this file
# (``_check_fflu_result``) never compares the computed factors with
# P_ans/L_ans/D_ans/U_ans element-wise; it only hands them to ``_to_DM``,
# which reads ``.domain`` from them.  Some listed factors (e.g. 'zz_2x2') do
# not appear to satisfy the identity that the checker asserts, so treat these
# columns as unverified until confirmed.
FFLU_EXAMPLES = [
    # --- small dense integer matrices (wide, square, tall) ---
    (
        'zz_2x3',
        DM([[1, 2, 3], [4, 5, 6]], ZZ),
        DM([[1, 0], [0, 1]], ZZ),
        DM([[1, 0], [4, -3]], ZZ),
        DM([[1, 0], [0, -3]], ZZ),
        DM([[1, 2, 3], [0, -3, -6]], ZZ),
    ),
    (
        'zz_2x2',
        DM([[4, 3], [6, 3]], ZZ),
        DM([[1, 0], [0, 1]], ZZ),
        DM([[1, 0], [6, -6]], ZZ),
        DM([[4, 0], [0, -3]], ZZ),
        DM([[4, 3], [0, -3]], ZZ),
    ),
    (
        'zz_3x2',
        DM([[1, 2], [3, 4], [5, 6]], ZZ),
        DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ZZ),
        DM([[1, 0, 0], [3, 1, 0], [5, 2, 1]], ZZ),
        DM([[1, 0], [0, -2]], ZZ),
        DM([[1, 2], [0, -2], [0, 0]], ZZ),
    ),
    (
        'zz_3x3',
        DM([[1, 2, 3], [4, 5, 6], [7, 8, 9]], ZZ),
        DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ZZ),
        DM([[1, 0, 0], [4, 1, 0], [7, 2, 1]], ZZ),
        DM([[1, 0, 0], [0, -3, 0], [0, 0, 0]], ZZ),
        DM([[1, 2, 3], [0, -3, -6], [0, 0, 0]], ZZ),
    ),
    # --- all-zero and empty inputs ---
    (
        'zz_zero',
        DM([[0, 0, 0], [0, 0, 0], [0, 0, 0]], ZZ),
        DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ZZ),
        DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ZZ),
        DM([[0, 0, 0], [0, 0, 0], [0, 0, 0]], ZZ),
        DM([[0, 0, 0], [0, 0, 0], [0, 0, 0]], ZZ),
    ),
    (
        'zz_empty',
        DM([], ZZ),
        DM([], ZZ),
        DM([], ZZ),
        DM([], ZZ),
        DM([], ZZ),
    ),
    # The next two cases spell the shape out explicitly: (0, 2) and (2, 0).
    (
        'zz_empty_0x2',
        DomainMatrix([], (0, 2), ZZ),
        DomainMatrix([], (0, 0), ZZ),
        DomainMatrix([], (0, 0), ZZ),
        DomainMatrix([], (0, 0), ZZ),
        DomainMatrix([], (0, 2), ZZ)
    ),
    (
        'zz_empty_2x0',
        DomainMatrix([[], []], (2, 0), ZZ),
        DomainMatrix.eye((2, 2), ZZ),
        DomainMatrix.eye((2, 2), ZZ),
        DomainMatrix.eye((2, 2), ZZ),
        DomainMatrix([[], []], (2, 0), ZZ)
    ),
    # --- sign handling ---
    (
        'zz_negative',
        DM([[-1, -2], [-3, -4]], ZZ),
        DM([[1, 0], [0, 1]], ZZ),
        DM([[-1, 0], [-3, -2]], ZZ),
        DM([[-1, 0], [0, 2]], ZZ),
        DM([[-1, -2], [0, -2]], ZZ),
    ),
    (
        'zz_mixed_signs',
        DM([[1, -2], [-3, 4]], ZZ),
        DM([[1, 0], [0, 1]], ZZ),
        DM([[1, 0], [-3, 1]], ZZ),
        DM([[1, 0], [0, -2]], ZZ),
        DM([[1, -2], [0, -2]], ZZ),
    ),
    # --- inputs that are already triangular or diagonal ---
    (
        'zz_upper_triangular',
        DM([[1, 2, 3], [0, 4, 5], [0, 0, 6]], ZZ),
        DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ZZ),
        DM([[1, 0, 0], [0, 4, 0], [0, 0, 24]], ZZ),
        DM([[1, 0, 0], [0, 4, 0], [0, 0, 96]], ZZ),
        DM([[1, 2, 3], [0, 4, 5], [0, 0, 24]], ZZ),
    ),
    (
        'zz_lower_triangular',
        DM([[1, 0, 0], [2, 3, 0], [4, 5, 6]], ZZ),
        DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ZZ),
        DM([[1, 0, 0], [2, 3, 0], [4, 5, 18]], ZZ),
        DM([[1, 0, 0], [0, 3, 0], [0, 0, 54]], ZZ),
        DM([[1, 0, 0], [0, 3, 0], [0, 0, 18]], ZZ),
    ),
    (
        'zz_diagonal',
        DM([[2, 0, 0], [0, 3, 0], [0, 0, 4]], ZZ),
        DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ZZ),
        DM([[2, 0, 0], [0, 6, 0], [0, 0, 24]], ZZ),
        DM([[2, 0, 0], [0, 12, 0], [0, 0, 144]], ZZ),
        DM([[2, 0, 0], [0, 6, 0], [0, 0, 24]], ZZ)
    ),
    # Every row of this input is a multiple of the first row.
    (
        'rank_deficient_3x3',
        DM([[1, 2, 3], [2, 4, 6], [3, 6, 9]], ZZ),
        DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ZZ),
        DM([[1, 0, 0], [2, 1, 0], [3, 0, 1]], ZZ),
        DM([[1, 0, 0], [0, 0, 0], [0, 0, 0]], ZZ),
        DM([[1, 2, 3], [0, 0, 0], [0, 0, 0]], ZZ),
    ),
    # --- assorted shapes: 1x1, single column, two columns, single row ---
    (
        'zz_1x1',
        DM([[5]], ZZ),
        DM([[1]], ZZ),
        DM([[5]], ZZ),
        DM([[5]], ZZ),
        DM([[5]], ZZ),
    ),
    (
        'zz_nx1_2rows',
        DM([[81], [54]], ZZ),
        DM([[1, 0], [0, 1]], ZZ),
        DM([[81, 0], [54, 81]], ZZ),
        DM([[81, 0], [0, 81]], ZZ),
        DM([[81], [0]], ZZ),
    ),
    (
        'zz_nx2_3rows',
        DM([[2, 7], [7, 45], [25, 84]], ZZ),
        DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ZZ),
        DM([[2, 0, 0], [7, 82, 0], [25, 41, 41]], ZZ),
        DM([[2, 0, 0], [0, 82, 0], [0, 0, 41]], ZZ),
        DM([[2, 7], [0, 82], [0, 0]], ZZ),
    ),
    # Single row whose first entry is zero.
    (
        'zz_1x2',
        DM([[0, 28]], ZZ),
        DM([[1]], ZZ),
        DM([[28]], ZZ),
        DM([[28]], ZZ),
        DM([[0, 28]], ZZ)
    ),
    (
        'zz_nx3_4rows',
        DM([[84, 30, 9], [20, 59, 13], [53, 46, 81], [63, 48, 29]], ZZ),
        DM([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], ZZ),
        DM([[84, 0, 0, 0], [20, 365904, 0, 0], [53, 303411, 303411, 0], [63, 303411, 303411, 303411]], ZZ),
        DM([[84, 0, 0, 0], [0, 365904, 0, 0], [0, 0, 1321658316, 0], [0, 0, 0, 303411]], ZZ),
        DM([[84, 30, 9], [0, 365904, 13], [0, 0, 1321658316], [0, 0, 0]], ZZ),
    ),
    # The input has a zero in the top-left entry; the listed P is the only
    # non-identity permutation in this table (it exchanges rows 0 and 1).
    (
        'fflu_row_swap',
        DM([[0, 1, 2], [3, 4, 5], [6, 7, 8]], ZZ),
        DM([[0, 1, 0], [1, 0, 0], [0, 0, 1]], ZZ),
        DM([[3, 0, 0], [0, 3, 0], [6, -3, 1]], ZZ),
        DM([[3, 0, 0], [0, 9, 0], [0, 0, 3]], ZZ),
        DM([[3, 4, 5], [0, 3, 6], [0, 0, 0]], ZZ)
    ),
]
  167. def _check_fflu(A, P, L, D, U):
  168. P_field = P.to_field().to_dense()
  169. L_field = L.to_field().to_dense()
  170. D_field = D.to_field().to_dense()
  171. U_field = U.to_field().to_dense()
  172. m, n = A.shape
  173. assert P_field.shape == (m, m)
  174. assert L_field.shape == (m, m)
  175. assert D_field.shape == (m, m)
  176. assert U_field.shape == (m, n)
  177. assert L_field.is_lower
  178. assert D_field.is_diagonal
  179. di, d = D.inv_den()
  180. assert P.matmul(A).rmul(d) == L.matmul(di).matmul(U)
  181. assert U_field.is_upper
  182. def _to_DM(A, ans):
  183. if isinstance(A, DomainMatrix):
  184. return A
  185. elif isinstance(A, Matrix):
  186. return A.to_DM(ans.domain)
  187. return DomainMatrix(A.to_list(), A.shape, A.domain)
  188. def _check_fflu_result(result, A, P_ans, L_ans, D_ans, U_ans):
  189. P, L, D, U = result
  190. P = _to_DM(P, P_ans)
  191. L = _to_DM(L, L_ans)
  192. D = _to_DM(D, D_ans)
  193. U = _to_DM(U, U_ans)
  194. A = _to_DM(A, P_ans)
  195. m, n = A.shape
  196. assert P.shape == (m, m)
  197. assert L.shape == (m, m)
  198. assert D.shape == (m, m)
  199. assert U.shape == (m, n)
  200. assert L.is_lower
  201. assert D.is_diagonal
  202. di, d = D.inv_den()
  203. assert P.matmul(A).rmul(d) == L.matmul(di).matmul(U)
  204. assert U.is_upper
  205. @pytest.mark.parametrize('name, A, P_ans, L_ans, D_ans, U_ans', FFLU_EXAMPLES)
  206. def test_dm_dense_fflu(name, A, P_ans, L_ans, D_ans, U_ans):
  207. A = A.to_dense()
  208. _check_fflu_result(A.fflu(), A, P_ans, L_ans, D_ans, U_ans)
  209. @pytest.mark.parametrize('name, A, P_ans, L_ans, D_ans, U_ans', FFLU_EXAMPLES)
  210. def test_dm_sparse_fflu(name, A, P_ans, L_ans, D_ans, U_ans):
  211. A = A.to_sparse()
  212. _check_fflu_result(A.fflu(), A, P_ans, L_ans, D_ans, U_ans)
  213. @pytest.mark.parametrize('name, A, P_ans, L_ans, D_ans, U_ans', FFLU_EXAMPLES)
  214. def test_ddm_fflu(name, A, P_ans, L_ans, D_ans, U_ans):
  215. A = A.to_ddm()
  216. _check_fflu_result(A.fflu(), A, P_ans, L_ans, D_ans, U_ans)
  217. @pytest.mark.parametrize('name, A, P_ans, L_ans, D_ans, U_ans', FFLU_EXAMPLES)
  218. def test_sdm_fflu(name, A, P_ans, L_ans, D_ans, U_ans):
  219. A = A.to_sdm()
  220. _check_fflu_result(A.fflu(), A, P_ans, L_ans, D_ans, U_ans)
  221. @pytest.mark.parametrize('name, A, P_ans, L_ans, D_ans, U_ans', FFLU_EXAMPLES)
  222. def test_dfm_fflu(name, A, P_ans, L_ans, D_ans, U_ans):
  223. pytest.importorskip('flint')
  224. if A.domain not in (ZZ, QQ) and not A.domain.is_FF:
  225. pytest.skip("Domain not supported by DFM")
  226. A = A.to_dfm()
  227. _check_fflu_result(A.fflu(), A, P_ans, L_ans, D_ans, U_ans)
  228. def test_fflu_empty_matrix():
  229. A = DomainMatrix([], (0, 0), ZZ)
  230. P, L, D, U = A.fflu()
  231. assert P.shape == (0, 0)
  232. assert L.shape == (0, 0)
  233. assert D.shape == (0, 0)
  234. assert U.shape == (0, 0)
  235. def test_fflu_properties():
  236. A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ)
  237. P, L, D, U = A.fflu()
  238. assert P.shape == (2, 2)
  239. assert L.shape == (2, 2)
  240. assert D.shape == (2, 2)
  241. assert U.shape == (2, 2)
  242. assert L.is_lower
  243. assert U.is_upper
  244. assert D.is_diagonal
  245. di, d = D.inv_den()
  246. assert P.matmul(A).rmul(d) == L.matmul(di).matmul(U)
  247. def test_fflu_rank_deficient():
  248. A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(2), ZZ(4)]], (2, 2), ZZ)
  249. P, L, D, U = A.fflu()
  250. assert P.shape == (2, 2)
  251. assert L.shape == (2, 2)
  252. assert D.shape == (2, 2)
  253. assert U.shape == (2, 2)
  254. assert U.getitem_sympy(1, 1) == 0