test_decomp_lu.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. import pytest
  2. from pytest import raises as assert_raises
  3. import numpy as np
  4. from scipy.linalg import lu, lu_factor, lu_solve, get_lapack_funcs, solve
  5. from numpy.testing import assert_allclose, assert_array_equal, assert_equal
  6. REAL_DTYPES = [np.float32, np.float64]
  7. COMPLEX_DTYPES = [np.complex64, np.complex128]
  8. DTYPES = REAL_DTYPES + COMPLEX_DTYPES
  9. class TestLU:
  10. def setup_method(self):
  11. self.rng = np.random.default_rng(1682281250228846)
  12. def test_old_lu_smoke_tests(self):
  13. "Tests from old fortran based lu test suite"
  14. a = np.array([[1, 2, 3], [1, 2, 3], [2, 5, 6]])
  15. p, l, u = lu(a)
  16. result_lu = np.array([[2., 5., 6.], [0.5, -0.5, 0.], [0.5, 1., 0.]])
  17. assert_allclose(p, np.rot90(np.eye(3)))
  18. assert_allclose(l, np.tril(result_lu, k=-1)+np.eye(3))
  19. assert_allclose(u, np.triu(result_lu))
  20. a = np.array([[1, 2, 3], [1, 2, 3], [2, 5j, 6]])
  21. p, l, u = lu(a)
  22. result_lu = np.array([[2., 5.j, 6.], [0.5, 2-2.5j, 0.], [0.5, 1., 0.]])
  23. assert_allclose(p, np.rot90(np.eye(3)))
  24. assert_allclose(l, np.tril(result_lu, k=-1)+np.eye(3))
  25. assert_allclose(u, np.triu(result_lu))
  26. b = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  27. p, l, u = lu(b)
  28. assert_allclose(p, np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]))
  29. assert_allclose(l, np.array([[1, 0, 0], [1/7, 1, 0], [4/7, 0.5, 1]]))
  30. assert_allclose(u, np.array([[7, 8, 9], [0, 6/7, 12/7], [0, 0, 0]]),
  31. rtol=0., atol=1e-14)
  32. cb = np.array([[1.j, 2.j, 3.j], [4j, 5j, 6j], [7j, 8j, 9j]])
  33. p, l, u = lu(cb)
  34. assert_allclose(p, np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]))
  35. assert_allclose(l, np.array([[1, 0, 0], [1/7, 1, 0], [4/7, 0.5, 1]]))
  36. assert_allclose(u, np.array([[7, 8, 9], [0, 6/7, 12/7], [0, 0, 0]])*1j,
  37. rtol=0., atol=1e-14)
  38. # Rectangular matrices
  39. hrect = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 12, 12]])
  40. p, l, u = lu(hrect)
  41. assert_allclose(p, np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]))
  42. assert_allclose(l, np.array([[1, 0, 0], [1/9, 1, 0], [5/9, 0.5, 1]]))
  43. assert_allclose(u, np.array([[9, 10, 12, 12], [0, 8/9, 15/9, 24/9],
  44. [0, 0, -0.5, 0]]), rtol=0., atol=1e-14)
  45. chrect = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 12, 12]])*1.j
  46. p, l, u = lu(chrect)
  47. assert_allclose(p, np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]))
  48. assert_allclose(l, np.array([[1, 0, 0], [1/9, 1, 0], [5/9, 0.5, 1]]))
  49. assert_allclose(u, np.array([[9, 10, 12, 12], [0, 8/9, 15/9, 24/9],
  50. [0, 0, -0.5, 0]])*1j, rtol=0., atol=1e-14)
  51. vrect = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 12, 12]])
  52. p, l, u = lu(vrect)
  53. assert_allclose(p, np.eye(4)[[1, 3, 2, 0], :])
  54. assert_allclose(l, np.array([[1., 0, 0], [0.1, 1, 0], [0.7, -0.5, 1],
  55. [0.4, 0.25, 0.5]]))
  56. assert_allclose(u, np.array([[10, 12, 12],
  57. [0, 0.8, 1.8],
  58. [0, 0, 1.5]]))
  59. cvrect = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 12, 12]])*1j
  60. p, l, u = lu(cvrect)
  61. assert_allclose(p, np.eye(4)[[1, 3, 2, 0], :])
  62. assert_allclose(l, np.array([[1., 0, 0],
  63. [0.1, 1, 0],
  64. [0.7, -0.5, 1],
  65. [0.4, 0.25, 0.5]]))
  66. assert_allclose(u, np.array([[10, 12, 12],
  67. [0, 0.8, 1.8],
  68. [0, 0, 1.5]])*1j)
  69. @pytest.mark.parametrize('shape', [[2, 2], [2, 4], [4, 2], [20, 20],
  70. [20, 4], [4, 20], [3, 2, 9, 9],
  71. [2, 2, 17, 5], [2, 2, 11, 7]])
  72. def test_simple_lu_shapes_real_complex(self, shape):
  73. a = self.rng.uniform(-10., 10., size=shape)
  74. p, l, u = lu(a)
  75. assert_allclose(a, p @ l @ u)
  76. pl, u = lu(a, permute_l=True)
  77. assert_allclose(a, pl @ u)
  78. b = self.rng.uniform(-10., 10., size=shape)*1j
  79. b += self.rng.uniform(-10, 10, size=shape)
  80. pl, u = lu(b, permute_l=True)
  81. assert_allclose(b, pl @ u)
  82. @pytest.mark.parametrize('shape', [[2, 2], [2, 4], [4, 2], [20, 20],
  83. [20, 4], [4, 20]])
  84. def test_simple_lu_shapes_real_complex_2d_indices(self, shape):
  85. a = self.rng.uniform(-10., 10., size=shape)
  86. p, l, u = lu(a, p_indices=True)
  87. assert_allclose(a, l[p, :] @ u)
  88. def test_1by1_input_output(self):
  89. a = self.rng.random([4, 5, 1, 1], dtype=np.float32)
  90. p, l, u = lu(a, p_indices=True)
  91. assert_allclose(p, np.zeros(shape=(4, 5, 1), dtype=int))
  92. assert_allclose(l, np.ones(shape=(4, 5, 1, 1), dtype=np.float32))
  93. assert_allclose(u, a)
  94. a = self.rng.random([4, 5, 1, 1], dtype=np.float32)
  95. p, l, u = lu(a)
  96. assert_allclose(p, np.ones(shape=(4, 5, 1, 1), dtype=np.float32))
  97. assert_allclose(l, np.ones(shape=(4, 5, 1, 1), dtype=np.float32))
  98. assert_allclose(u, a)
  99. pl, u = lu(a, permute_l=True)
  100. assert_allclose(pl, np.ones(shape=(4, 5, 1, 1), dtype=np.float32))
  101. assert_allclose(u, a)
  102. a = self.rng.random([4, 5, 1, 1], dtype=np.float32)*np.complex64(1.j)
  103. p, l, u = lu(a)
  104. assert_allclose(p, np.ones(shape=(4, 5, 1, 1), dtype=np.complex64))
  105. assert_allclose(l, np.ones(shape=(4, 5, 1, 1), dtype=np.complex64))
  106. assert_allclose(u, a)
  107. def test_empty_edge_cases(self):
  108. a = np.empty([0, 0])
  109. p, l, u = lu(a)
  110. assert_allclose(p, np.empty(shape=(0, 0), dtype=np.float64))
  111. assert_allclose(l, np.empty(shape=(0, 0), dtype=np.float64))
  112. assert_allclose(u, np.empty(shape=(0, 0), dtype=np.float64))
  113. a = np.empty([0, 3], dtype=np.float16)
  114. p, l, u = lu(a)
  115. assert_allclose(p, np.empty(shape=(0, 0), dtype=np.float32))
  116. assert_allclose(l, np.empty(shape=(0, 0), dtype=np.float32))
  117. assert_allclose(u, np.empty(shape=(0, 3), dtype=np.float32))
  118. a = np.empty([3, 0], dtype=np.complex64)
  119. p, l, u = lu(a)
  120. assert_allclose(p, np.empty(shape=(0, 0), dtype=np.float32))
  121. assert_allclose(l, np.empty(shape=(3, 0), dtype=np.complex64))
  122. assert_allclose(u, np.empty(shape=(0, 0), dtype=np.complex64))
  123. p, l, u = lu(a, p_indices=True)
  124. assert_allclose(p, np.empty(shape=(0,), dtype=int))
  125. assert_allclose(l, np.empty(shape=(3, 0), dtype=np.complex64))
  126. assert_allclose(u, np.empty(shape=(0, 0), dtype=np.complex64))
  127. pl, u = lu(a, permute_l=True)
  128. assert_allclose(pl, np.empty(shape=(3, 0), dtype=np.complex64))
  129. assert_allclose(u, np.empty(shape=(0, 0), dtype=np.complex64))
  130. a = np.empty([3, 0, 0], dtype=np.complex64)
  131. p, l, u = lu(a)
  132. assert_allclose(p, np.empty(shape=(3, 0, 0), dtype=np.float32))
  133. assert_allclose(l, np.empty(shape=(3, 0, 0), dtype=np.complex64))
  134. assert_allclose(u, np.empty(shape=(3, 0, 0), dtype=np.complex64))
  135. a = np.empty([0, 0, 3])
  136. p, l, u = lu(a)
  137. assert_allclose(p, np.empty(shape=(0, 0, 0)))
  138. assert_allclose(l, np.empty(shape=(0, 0, 0)))
  139. assert_allclose(u, np.empty(shape=(0, 0, 3)))
  140. with assert_raises(ValueError, match='at least two-dimensional'):
  141. lu(np.array([]))
  142. a = np.array([[]])
  143. p, l, u = lu(a)
  144. assert_allclose(p, np.empty(shape=(0, 0)))
  145. assert_allclose(l, np.empty(shape=(1, 0)))
  146. assert_allclose(u, np.empty(shape=(0, 0)))
  147. a = np.array([[[]]])
  148. p, l, u = lu(a)
  149. assert_allclose(p, np.empty(shape=(1, 0, 0)))
  150. assert_allclose(l, np.empty(shape=(1, 1, 0)))
  151. assert_allclose(u, np.empty(shape=(1, 0, 0)))
  152. class TestLUFactor:
  153. def setup_method(self):
  154. self.rng = np.random.default_rng(1682281250228846)
  155. self.a = np.array([[1, 2, 3], [1, 2, 3], [2, 5, 6]])
  156. self.ca = np.array([[1, 2, 3], [1, 2, 3], [2, 5j, 6]])
  157. # Those matrices are more robust to detect problems in permutation
  158. # matrices than the ones above
  159. self.b = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  160. self.cb = np.array([[1j, 2j, 3j], [4j, 5j, 6j], [7j, 8j, 9j]])
  161. # Rectangular matrices
  162. self.hrect = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 12, 12]])
  163. self.chrect = np.array([[1, 2, 3, 4], [5, 6, 7, 8],
  164. [9, 10, 12, 12]]) * 1.j
  165. self.vrect = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 12, 12]])
  166. self.cvrect = 1.j * np.array([[1, 2, 3],
  167. [4, 5, 6],
  168. [7, 8, 9],
  169. [10, 12, 12]])
  170. # Medium sizes matrices
  171. self.med = self.rng.random((30, 40))
  172. self.cmed = self.rng.random((30, 40)) + 1.j*self.rng.random((30, 40))
  173. def _test_common_lu_factor(self, data):
  174. l_and_u1, piv1 = lu_factor(data)
  175. (getrf,) = get_lapack_funcs(("getrf",), (data,))
  176. l_and_u2, piv2, _ = getrf(data, overwrite_a=False)
  177. assert_allclose(l_and_u1, l_and_u2)
  178. assert_allclose(piv1, piv2)
  179. # Simple tests.
  180. # For lu_factor gives a LinAlgWarning because these matrices are singular
  181. def test_hrectangular(self):
  182. self._test_common_lu_factor(self.hrect)
  183. def test_vrectangular(self):
  184. self._test_common_lu_factor(self.vrect)
  185. def test_hrectangular_complex(self):
  186. self._test_common_lu_factor(self.chrect)
  187. def test_vrectangular_complex(self):
  188. self._test_common_lu_factor(self.cvrect)
  189. # Bigger matrices
  190. def test_medium1(self):
  191. """Check lu decomposition on medium size, rectangular matrix."""
  192. self._test_common_lu_factor(self.med)
  193. def test_medium1_complex(self):
  194. """Check lu decomposition on medium size, rectangular matrix."""
  195. self._test_common_lu_factor(self.cmed)
  196. def test_check_finite(self):
  197. p, l, u = lu(self.a, check_finite=False)
  198. assert_allclose(p @ l @ u, self.a)
  199. def test_simple_known(self):
  200. # Ticket #1458
  201. for order in ['C', 'F']:
  202. A = np.array([[2, 1], [0, 1.]], order=order)
  203. LU, P = lu_factor(A)
  204. assert_allclose(LU, np.array([[2, 1], [0, 1]]))
  205. assert_array_equal(P, np.array([0, 1]))
  206. @pytest.mark.parametrize("m", [0, 1, 2])
  207. @pytest.mark.parametrize("n", [0, 1, 2])
  208. @pytest.mark.parametrize('dtype', DTYPES)
  209. def test_shape_dtype(self, m, n, dtype):
  210. k = min(m, n)
  211. a = np.eye(m, n, dtype=dtype)
  212. lu, p = lu_factor(a)
  213. assert_equal(lu.shape, (m, n))
  214. assert_equal(lu.dtype, dtype)
  215. assert_equal(p.shape, (k,))
  216. assert_equal(p.dtype, np.int32)
  217. @pytest.mark.parametrize(("m", "n"), [(0, 0), (0, 2), (2, 0)])
  218. def test_empty(self, m, n):
  219. a = np.zeros((m, n))
  220. lu, p = lu_factor(a)
  221. assert_allclose(lu, np.empty((m, n)))
  222. assert_allclose(p, np.arange(0))
  223. class TestLUSolve:
  224. def setup_method(self):
  225. self.rng = np.random.default_rng(1682281250228846)
  226. def test_lu(self):
  227. a0 = self.rng.random((10, 10))
  228. b = self.rng.random((10,))
  229. for order in ['C', 'F']:
  230. a = np.array(a0, order=order)
  231. x1 = solve(a, b)
  232. lu_a = lu_factor(a)
  233. x2 = lu_solve(lu_a, b)
  234. assert_allclose(x1, x2)
  235. def test_check_finite(self):
  236. a = self.rng.random((10, 10))
  237. b = self.rng.random((10,))
  238. x1 = solve(a, b)
  239. lu_a = lu_factor(a, check_finite=False)
  240. x2 = lu_solve(lu_a, b, check_finite=False)
  241. assert_allclose(x1, x2)
  242. @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  243. @pytest.mark.parametrize('dt_b', [int, float, np.float32, complex, np.complex64])
  244. def test_empty(self, dt, dt_b):
  245. lu_and_piv = (np.empty((0, 0), dtype=dt), np.array([]))
  246. b = np.asarray([], dtype=dt_b)
  247. x = lu_solve(lu_and_piv, b)
  248. assert x.shape == (0,)
  249. m = lu_solve((np.eye(2, dtype=dt), [0, 1]), np.ones(2, dtype=dt_b))
  250. assert x.dtype == m.dtype
  251. b = np.empty((0, 0), dtype=dt_b)
  252. x = lu_solve(lu_and_piv, b)
  253. assert x.shape == (0, 0)
  254. assert x.dtype == m.dtype