_decomp_lu.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. """LU decomposition functions."""
  2. from warnings import warn
  3. from numpy import asarray, asarray_chkfinite
  4. import numpy as np
  5. from itertools import product
  6. from scipy._lib._util import _apply_over_batch
  7. # Local imports
  8. from ._misc import _datacopied, LinAlgWarning
  9. from .lapack import get_lapack_funcs, _normalize_lapack_dtype
  10. from ._decomp_lu_cython import lu_dispatcher
  11. __all__ = ['lu', 'lu_solve', 'lu_factor']
  12. @_apply_over_batch(('a', 2))
  13. def lu_factor(a, overwrite_a=False, check_finite=True):
  14. """
  15. Compute pivoted LU decomposition of a matrix.
  16. The decomposition is::
  17. A = P L U
  18. where P is a permutation matrix, L lower triangular with unit
  19. diagonal elements, and U upper triangular.
  20. Parameters
  21. ----------
  22. a : (M, N) array_like
  23. Matrix to decompose
  24. overwrite_a : bool, optional
  25. Whether to overwrite data in A (may increase performance)
  26. check_finite : bool, optional
  27. Whether to check that the input matrix contains only finite numbers.
  28. Disabling may give a performance gain, but may result in problems
  29. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  30. Returns
  31. -------
  32. lu : (M, N) ndarray
  33. Matrix containing U in its upper triangle, and L in its lower triangle.
  34. The unit diagonal elements of L are not stored.
  35. piv : (K,) ndarray
  36. Pivot indices representing the permutation matrix P:
  37. row i of matrix was interchanged with row piv[i].
  38. Of shape ``(K,)``, with ``K = min(M, N)``.
  39. See Also
  40. --------
  41. lu : gives lu factorization in more user-friendly format
  42. lu_solve : solve an equation system using the LU factorization of a matrix
  43. Notes
  44. -----
  45. This is a wrapper to the ``*GETRF`` routines from LAPACK. Unlike
  46. :func:`lu`, it outputs the L and U factors into a single array
  47. and returns pivot indices instead of a permutation matrix.
  48. While the underlying ``*GETRF`` routines return 1-based pivot indices, the
  49. ``piv`` array returned by ``lu_factor`` contains 0-based indices.
  50. Examples
  51. --------
  52. >>> import numpy as np
  53. >>> from scipy.linalg import lu_factor
  54. >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
  55. >>> lu, piv = lu_factor(A)
  56. >>> piv
  57. array([2, 2, 3, 3], dtype=int32)
  58. Convert LAPACK's ``piv`` array to NumPy index and test the permutation
  59. >>> def pivot_to_permutation(piv):
  60. ... perm = np.arange(len(piv))
  61. ... for i in range(len(piv)):
  62. ... perm[i], perm[piv[i]] = perm[piv[i]], perm[i]
  63. ... return perm
  64. ...
  65. >>> p_inv = pivot_to_permutation(piv)
  66. >>> p_inv
  67. array([2, 0, 3, 1])
  68. >>> L, U = np.tril(lu, k=-1) + np.eye(4), np.triu(lu)
  69. >>> np.allclose(A[p_inv] - L @ U, np.zeros((4, 4)))
  70. True
  71. The P matrix in P L U is defined by the inverse permutation and
  72. can be recovered using argsort:
  73. >>> p = np.argsort(p_inv)
  74. >>> p
  75. array([1, 3, 0, 2])
  76. >>> np.allclose(A - L[p] @ U, np.zeros((4, 4)))
  77. True
  78. or alternatively:
  79. >>> P = np.eye(4)[p]
  80. >>> np.allclose(A - P @ L @ U, np.zeros((4, 4)))
  81. True
  82. """
  83. if check_finite:
  84. a1 = asarray_chkfinite(a)
  85. else:
  86. a1 = asarray(a)
  87. # accommodate empty arrays
  88. if a1.size == 0:
  89. lu = np.empty_like(a1)
  90. piv = np.arange(0, dtype=np.int32)
  91. return lu, piv
  92. overwrite_a = overwrite_a or (_datacopied(a1, a))
  93. getrf, = get_lapack_funcs(('getrf',), (a1,))
  94. lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
  95. if info < 0:
  96. raise ValueError(
  97. f'illegal value in {-info}th argument of internal getrf (lu_factor)'
  98. )
  99. if info > 0:
  100. warn(
  101. f"Diagonal number {info} is exactly zero. Singular matrix.",
  102. LinAlgWarning,
  103. stacklevel=2
  104. )
  105. return lu, piv
  106. def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
  107. """Solve an equation system, a x = b, given the LU factorization of a
  108. The documentation is written assuming array arguments are of specified
  109. "core" shapes. However, array argument(s) of this function may have additional
  110. "batch" dimensions prepended to the core shape. In this case, the array is treated
  111. as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.
  112. Parameters
  113. ----------
  114. (lu, piv)
  115. Factorization of the coefficient matrix a, as given by lu_factor.
  116. In particular piv are 0-indexed pivot indices.
  117. b : array
  118. Right-hand side
  119. trans : {0, 1, 2}, optional
  120. Type of system to solve:
  121. ===== =========
  122. trans system
  123. ===== =========
  124. 0 a x = b
  125. 1 a^T x = b
  126. 2 a^H x = b
  127. ===== =========
  128. overwrite_b : bool, optional
  129. Whether to overwrite data in b (may increase performance)
  130. check_finite : bool, optional
  131. Whether to check that the input matrices contain only finite numbers.
  132. Disabling may give a performance gain, but may result in problems
  133. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  134. Returns
  135. -------
  136. x : array
  137. Solution to the system
  138. See Also
  139. --------
  140. lu_factor : LU factorize a matrix
  141. Examples
  142. --------
  143. >>> import numpy as np
  144. >>> from scipy.linalg import lu_factor, lu_solve
  145. >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
  146. >>> b = np.array([1, 1, 1, 1])
  147. >>> lu, piv = lu_factor(A)
  148. >>> x = lu_solve((lu, piv), b)
  149. >>> np.allclose(A @ x - b, np.zeros((4,)))
  150. True
  151. """
  152. (lu, piv) = lu_and_piv
  153. return _lu_solve(lu, piv, b, trans=trans, overwrite_b=overwrite_b,
  154. check_finite=check_finite)
  155. @_apply_over_batch(('lu', 2), ('piv', 1), ('b', '1|2'))
  156. def _lu_solve(lu, piv, b, trans, overwrite_b, check_finite):
  157. if check_finite:
  158. b1 = asarray_chkfinite(b)
  159. else:
  160. b1 = asarray(b)
  161. overwrite_b = overwrite_b or _datacopied(b1, b)
  162. if lu.shape[0] != b1.shape[0]:
  163. raise ValueError(f"Shapes of lu {lu.shape} and b {b1.shape} are incompatible")
  164. # accommodate empty arrays
  165. if b1.size == 0:
  166. m = lu_solve((np.eye(2, dtype=lu.dtype), [0, 1]), np.ones(2, dtype=b.dtype))
  167. return np.empty_like(b1, dtype=m.dtype)
  168. getrs, = get_lapack_funcs(('getrs',), (lu, b1))
  169. x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
  170. if info == 0:
  171. return x
  172. raise ValueError(f'illegal value in {-info}th argument of internal gesv|posv')
  173. def lu(a, permute_l=False, overwrite_a=False, check_finite=True,
  174. p_indices=False):
  175. """
  176. Compute LU decomposition of a matrix with partial pivoting.
  177. The decomposition satisfies::
  178. A = P @ L @ U
  179. where ``P`` is a permutation matrix, ``L`` lower triangular with unit
  180. diagonal elements, and ``U`` upper triangular. If `permute_l` is set to
  181. ``True`` then ``L`` is returned already permuted and hence satisfying
  182. ``A = L @ U``.
  183. Array argument(s) of this function may have additional
  184. "batch" dimensions prepended to the core shape. In this case, the array is treated
  185. as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.
  186. Parameters
  187. ----------
  188. a : (..., M, N) array_like
  189. Array to decompose
  190. permute_l : bool, optional
  191. Perform the multiplication P*L (Default: do not permute)
  192. overwrite_a : bool, optional
  193. Whether to overwrite data in a (may improve performance)
  194. check_finite : bool, optional
  195. Whether to check that the input matrix contains only finite numbers.
  196. Disabling may give a performance gain, but may result in problems
  197. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  198. p_indices : bool, optional
  199. If ``True`` the permutation information is returned as row indices.
  200. The default is ``False`` for backwards-compatibility reasons.
  201. Returns
  202. -------
  203. (p, l, u) | (pl, u):
  204. The tuple `(p, l, u)` is returned if `permute_l` is ``False`` (default) else
  205. the tuple `(pl, u)` is returned, where:
  206. p : (..., M, M) ndarray
  207. Permutation arrays or vectors depending on `p_indices`.
  208. l : (..., M, K) ndarray
  209. Lower triangular or trapezoidal array with unit diagonal, where the last
  210. dimension is ``K = min(M, N)``.
  211. pl : (..., M, K) ndarray
  212. Permuted L matrix with last dimension being ``K = min(M, N)``.
  213. u : (..., K, N) ndarray
  214. Upper triangular or trapezoidal array.
  215. Notes
  216. -----
  217. Permutation matrices are costly since they are nothing but row reorder of
  218. ``L`` and hence indices are strongly recommended to be used instead if the
  219. permutation is required. The relation in the 2D case then becomes simply
  220. ``A = L[P, :] @ U``. In higher dimensions, it is better to use `permute_l`
  221. to avoid complicated indexing tricks.
  222. In 2D case, if one has the indices however, for some reason, the
  223. permutation matrix is still needed then it can be constructed by
  224. ``np.eye(M)[P, :]``.
  225. Examples
  226. --------
  227. >>> import numpy as np
  228. >>> from scipy.linalg import lu
  229. >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
  230. >>> p, l, u = lu(A)
  231. >>> np.allclose(A, p @ l @ u)
  232. True
  233. >>> p # Permutation matrix
  234. array([[0., 1., 0., 0.], # Row index 1
  235. [0., 0., 0., 1.], # Row index 3
  236. [1., 0., 0., 0.], # Row index 0
  237. [0., 0., 1., 0.]]) # Row index 2
  238. >>> p, _, _ = lu(A, p_indices=True)
  239. >>> p
  240. array([1, 3, 0, 2], dtype=int32) # as given by row indices above
  241. >>> np.allclose(A, l[p, :] @ u)
  242. True
  243. We can also use nd-arrays, for example, a demonstration with 4D array:
  244. >>> rng = np.random.default_rng()
  245. >>> A = rng.uniform(low=-4, high=4, size=[3, 2, 4, 8])
  246. >>> p, l, u = lu(A)
  247. >>> p.shape, l.shape, u.shape
  248. ((3, 2, 4, 4), (3, 2, 4, 4), (3, 2, 4, 8))
  249. >>> np.allclose(A, p @ l @ u)
  250. True
  251. >>> PL, U = lu(A, permute_l=True)
  252. >>> np.allclose(A, PL @ U)
  253. True
  254. """
  255. a1 = np.asarray_chkfinite(a) if check_finite else np.asarray(a)
  256. if a1.ndim < 2:
  257. raise ValueError('The input array must be at least two-dimensional.')
  258. # Also check if dtype is LAPACK compatible
  259. a1, overwrite_a = _normalize_lapack_dtype(a1, overwrite_a)
  260. *nd, m, n = a1.shape
  261. k = min(m, n)
  262. real_dchar = 'f' if a1.dtype.char in 'fF' else 'd'
  263. # Empty input
  264. if min(*a1.shape) == 0:
  265. if permute_l:
  266. PL = np.empty(shape=[*nd, m, k], dtype=a1.dtype)
  267. U = np.empty(shape=[*nd, k, n], dtype=a1.dtype)
  268. return PL, U
  269. else:
  270. P = (np.empty([*nd, 0], dtype=np.int32) if p_indices else
  271. np.empty([*nd, 0, 0], dtype=real_dchar))
  272. L = np.empty(shape=[*nd, m, k], dtype=a1.dtype)
  273. U = np.empty(shape=[*nd, k, n], dtype=a1.dtype)
  274. return P, L, U
  275. # Scalar case
  276. if a1.shape[-2:] == (1, 1):
  277. if permute_l:
  278. return np.ones_like(a1), (a1 if overwrite_a else a1.copy())
  279. else:
  280. P = (np.zeros(shape=[*nd, m], dtype=int) if p_indices
  281. else np.ones_like(a1))
  282. return P, np.ones_like(a1), (a1 if overwrite_a else a1.copy())
  283. # Then check overwrite permission
  284. if not _datacopied(a1, a): # "a" still alive through "a1"
  285. if not overwrite_a:
  286. # Data belongs to "a" so make a copy
  287. a1 = a1.copy(order='C')
  288. # else: Do nothing we'll use "a" if possible
  289. # else: a1 has its own data thus free to scratch
  290. # Then layout checks, might happen that overwrite is allowed but original
  291. # array was read-only or non-contiguous.
  292. if not (a1.flags['C_CONTIGUOUS'] and a1.flags['WRITEABLE']):
  293. a1 = a1.copy(order='C')
  294. if not nd: # 2D array
  295. p = np.empty(m, dtype=np.int32)
  296. u = np.zeros([k, k], dtype=a1.dtype)
  297. lu_dispatcher(a1, u, p, permute_l)
  298. P, L, U = (p, a1, u) if m > n else (p, u, a1)
  299. else: # Stacked array
  300. # Prepare the contiguous data holders
  301. P = np.empty([*nd, m], dtype=np.int32) # perm vecs
  302. if m > n: # Tall arrays, U will be created
  303. U = np.zeros([*nd, k, k], dtype=a1.dtype)
  304. for ind in product(*[range(x) for x in a1.shape[:-2]]):
  305. lu_dispatcher(a1[ind], U[ind], P[ind], permute_l)
  306. L = a1
  307. else: # Fat arrays, L will be created
  308. L = np.zeros([*nd, k, k], dtype=a1.dtype)
  309. for ind in product(*[range(x) for x in a1.shape[:-2]]):
  310. lu_dispatcher(a1[ind], L[ind], P[ind], permute_l)
  311. U = a1
  312. # Convert permutation vecs to permutation arrays
  313. # permute_l=False needed to enter here to avoid wasted efforts
  314. if (not p_indices) and (not permute_l):
  315. if nd:
  316. Pa = np.zeros([*nd, m, m], dtype=real_dchar)
  317. # An unreadable index hack - One-hot encoding for perm matrices
  318. nd_ix = np.ix_(*([np.arange(x) for x in nd]+[np.arange(m)]))
  319. Pa[(*nd_ix, P)] = 1
  320. P = Pa
  321. else: # 2D case
  322. Pa = np.zeros([m, m], dtype=real_dchar)
  323. Pa[np.arange(m), P] = 1
  324. P = Pa
  325. return (L, U) if permute_l else (P, L, U)