| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397 |
- """LU decomposition functions."""
- from warnings import warn
- from numpy import asarray, asarray_chkfinite
- import numpy as np
- from itertools import product
- from scipy._lib._util import _apply_over_batch
- # Local imports
- from ._misc import _datacopied, LinAlgWarning
- from .lapack import get_lapack_funcs, _normalize_lapack_dtype
- from ._decomp_lu_cython import lu_dispatcher
- __all__ = ['lu', 'lu_solve', 'lu_factor']
- @_apply_over_batch(('a', 2))
- def lu_factor(a, overwrite_a=False, check_finite=True):
- """
- Compute pivoted LU decomposition of a matrix.
- The decomposition is::
- A = P L U
- where P is a permutation matrix, L lower triangular with unit
- diagonal elements, and U upper triangular.
- Parameters
- ----------
- a : (M, N) array_like
- Matrix to decompose
- overwrite_a : bool, optional
- Whether to overwrite data in A (may increase performance)
- check_finite : bool, optional
- Whether to check that the input matrix contains only finite numbers.
- Disabling may give a performance gain, but may result in problems
- (crashes, non-termination) if the inputs do contain infinities or NaNs.
- Returns
- -------
- lu : (M, N) ndarray
- Matrix containing U in its upper triangle, and L in its lower triangle.
- The unit diagonal elements of L are not stored.
- piv : (K,) ndarray
- Pivot indices representing the permutation matrix P:
- row i of matrix was interchanged with row piv[i].
- Of shape ``(K,)``, with ``K = min(M, N)``.
- See Also
- --------
- lu : gives lu factorization in more user-friendly format
- lu_solve : solve an equation system using the LU factorization of a matrix
- Notes
- -----
- This is a wrapper to the ``*GETRF`` routines from LAPACK. Unlike
- :func:`lu`, it outputs the L and U factors into a single array
- and returns pivot indices instead of a permutation matrix.
- While the underlying ``*GETRF`` routines return 1-based pivot indices, the
- ``piv`` array returned by ``lu_factor`` contains 0-based indices.
- Examples
- --------
- >>> import numpy as np
- >>> from scipy.linalg import lu_factor
- >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
- >>> lu, piv = lu_factor(A)
- >>> piv
- array([2, 2, 3, 3], dtype=int32)
- Convert LAPACK's ``piv`` array to NumPy index and test the permutation
- >>> def pivot_to_permutation(piv):
- ... perm = np.arange(len(piv))
- ... for i in range(len(piv)):
- ... perm[i], perm[piv[i]] = perm[piv[i]], perm[i]
- ... return perm
- ...
- >>> p_inv = pivot_to_permutation(piv)
- >>> p_inv
- array([2, 0, 3, 1])
- >>> L, U = np.tril(lu, k=-1) + np.eye(4), np.triu(lu)
- >>> np.allclose(A[p_inv] - L @ U, np.zeros((4, 4)))
- True
- The P matrix in P L U is defined by the inverse permutation and
- can be recovered using argsort:
- >>> p = np.argsort(p_inv)
- >>> p
- array([1, 3, 0, 2])
- >>> np.allclose(A - L[p] @ U, np.zeros((4, 4)))
- True
- or alternatively:
- >>> P = np.eye(4)[p]
- >>> np.allclose(A - P @ L @ U, np.zeros((4, 4)))
- True
- """
- if check_finite:
- a1 = asarray_chkfinite(a)
- else:
- a1 = asarray(a)
- # accommodate empty arrays
- if a1.size == 0:
- lu = np.empty_like(a1)
- piv = np.arange(0, dtype=np.int32)
- return lu, piv
- overwrite_a = overwrite_a or (_datacopied(a1, a))
- getrf, = get_lapack_funcs(('getrf',), (a1,))
- lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
- if info < 0:
- raise ValueError(
- f'illegal value in {-info}th argument of internal getrf (lu_factor)'
- )
- if info > 0:
- warn(
- f"Diagonal number {info} is exactly zero. Singular matrix.",
- LinAlgWarning,
- stacklevel=2
- )
- return lu, piv
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """Solve an equation system, a x = b, given the LU factorization of a

    The documentation is written assuming array arguments are of specified
    "core" shapes. However, array argument(s) of this function may have additional
    "batch" dimensions prepended to the core shape. In this case, the array is treated
    as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.

    Parameters
    ----------
    (lu, piv)
        Factorization of the coefficient matrix a, as produced by lu_factor.
        In particular piv are 0-indexed pivot indices.
    b : array
        Right-hand side
    trans : {0, 1, 2}, optional
        Type of system to solve:

        =====  =========
        trans  system
        =====  =========
        0      a x   = b
        1      a^T x = b
        2      a^H x = b
        =====  =========
    overwrite_b : bool, optional
        Whether the data in b may be overwritten (may increase performance)
    check_finite : bool, optional
        Whether to verify that the input matrices contain only finite numbers.
        Skipping the check may be faster, but infinities or NaNs in the
        inputs may then cause problems (crashes, non-termination).

    Returns
    -------
    x : array
        Solution to the system

    See Also
    --------
    lu_factor : LU factorize a matrix

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu_factor, lu_solve
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> b = np.array([1, 1, 1, 1])
    >>> lu, piv = lu_factor(A)
    >>> x = lu_solve((lu, piv), b)
    >>> np.allclose(A @ x - b, np.zeros((4,)))
    True

    """
    # Split the pair here so the batch-aware worker receives the factor
    # matrix and the pivot vector as separate array arguments.
    factors, pivots = lu_and_piv
    return _lu_solve(
        factors,
        pivots,
        b,
        trans=trans,
        overwrite_b=overwrite_b,
        check_finite=check_finite,
    )
- @_apply_over_batch(('lu', 2), ('piv', 1), ('b', '1|2'))
- def _lu_solve(lu, piv, b, trans, overwrite_b, check_finite):
- if check_finite:
- b1 = asarray_chkfinite(b)
- else:
- b1 = asarray(b)
- overwrite_b = overwrite_b or _datacopied(b1, b)
- if lu.shape[0] != b1.shape[0]:
- raise ValueError(f"Shapes of lu {lu.shape} and b {b1.shape} are incompatible")
- # accommodate empty arrays
- if b1.size == 0:
- m = lu_solve((np.eye(2, dtype=lu.dtype), [0, 1]), np.ones(2, dtype=b.dtype))
- return np.empty_like(b1, dtype=m.dtype)
- getrs, = get_lapack_funcs(('getrs',), (lu, b1))
- x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
- if info == 0:
- return x
- raise ValueError(f'illegal value in {-info}th argument of internal gesv|posv')
- def lu(a, permute_l=False, overwrite_a=False, check_finite=True,
- p_indices=False):
- """
- Compute LU decomposition of a matrix with partial pivoting.
- The decomposition satisfies::
- A = P @ L @ U
- where ``P`` is a permutation matrix, ``L`` lower triangular with unit
- diagonal elements, and ``U`` upper triangular. If `permute_l` is set to
- ``True`` then ``L`` is returned already permuted and hence satisfying
- ``A = L @ U``.
- Array argument(s) of this function may have additional
- "batch" dimensions prepended to the core shape. In this case, the array is treated
- as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.
- Parameters
- ----------
- a : (..., M, N) array_like
- Array to decompose
- permute_l : bool, optional
- Perform the multiplication P*L (Default: do not permute)
- overwrite_a : bool, optional
- Whether to overwrite data in a (may improve performance)
- check_finite : bool, optional
- Whether to check that the input matrix contains only finite numbers.
- Disabling may give a performance gain, but may result in problems
- (crashes, non-termination) if the inputs do contain infinities or NaNs.
- p_indices : bool, optional
- If ``True`` the permutation information is returned as row indices.
- The default is ``False`` for backwards-compatibility reasons.
- Returns
- -------
- (p, l, u) | (pl, u):
- The tuple `(p, l, u)` is returned if `permute_l` is ``False`` (default) else
- the tuple `(pl, u)` is returned, where:
- p : (..., M, M) ndarray
- Permutation arrays or vectors depending on `p_indices`.
- l : (..., M, K) ndarray
- Lower triangular or trapezoidal array with unit diagonal, where the last
- dimension is ``K = min(M, N)``.
- pl : (..., M, K) ndarray
- Permuted L matrix with last dimension being ``K = min(M, N)``.
- u : (..., K, N) ndarray
- Upper triangular or trapezoidal array.
- Notes
- -----
- Permutation matrices are costly since they are nothing but row reorder of
- ``L`` and hence indices are strongly recommended to be used instead if the
- permutation is required. The relation in the 2D case then becomes simply
- ``A = L[P, :] @ U``. In higher dimensions, it is better to use `permute_l`
- to avoid complicated indexing tricks.
- In 2D case, if one has the indices however, for some reason, the
- permutation matrix is still needed then it can be constructed by
- ``np.eye(M)[P, :]``.
- Examples
- --------
- >>> import numpy as np
- >>> from scipy.linalg import lu
- >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
- >>> p, l, u = lu(A)
- >>> np.allclose(A, p @ l @ u)
- True
- >>> p # Permutation matrix
- array([[0., 1., 0., 0.], # Row index 1
- [0., 0., 0., 1.], # Row index 3
- [1., 0., 0., 0.], # Row index 0
- [0., 0., 1., 0.]]) # Row index 2
- >>> p, _, _ = lu(A, p_indices=True)
- >>> p
- array([1, 3, 0, 2], dtype=int32) # as given by row indices above
- >>> np.allclose(A, l[p, :] @ u)
- True
- We can also use nd-arrays, for example, a demonstration with 4D array:
- >>> rng = np.random.default_rng()
- >>> A = rng.uniform(low=-4, high=4, size=[3, 2, 4, 8])
- >>> p, l, u = lu(A)
- >>> p.shape, l.shape, u.shape
- ((3, 2, 4, 4), (3, 2, 4, 4), (3, 2, 4, 8))
- >>> np.allclose(A, p @ l @ u)
- True
- >>> PL, U = lu(A, permute_l=True)
- >>> np.allclose(A, PL @ U)
- True
- """
- a1 = np.asarray_chkfinite(a) if check_finite else np.asarray(a)
- if a1.ndim < 2:
- raise ValueError('The input array must be at least two-dimensional.')
- # Also check if dtype is LAPACK compatible
- a1, overwrite_a = _normalize_lapack_dtype(a1, overwrite_a)
- *nd, m, n = a1.shape
- k = min(m, n)
- real_dchar = 'f' if a1.dtype.char in 'fF' else 'd'
- # Empty input
- if min(*a1.shape) == 0:
- if permute_l:
- PL = np.empty(shape=[*nd, m, k], dtype=a1.dtype)
- U = np.empty(shape=[*nd, k, n], dtype=a1.dtype)
- return PL, U
- else:
- P = (np.empty([*nd, 0], dtype=np.int32) if p_indices else
- np.empty([*nd, 0, 0], dtype=real_dchar))
- L = np.empty(shape=[*nd, m, k], dtype=a1.dtype)
- U = np.empty(shape=[*nd, k, n], dtype=a1.dtype)
- return P, L, U
- # Scalar case
- if a1.shape[-2:] == (1, 1):
- if permute_l:
- return np.ones_like(a1), (a1 if overwrite_a else a1.copy())
- else:
- P = (np.zeros(shape=[*nd, m], dtype=int) if p_indices
- else np.ones_like(a1))
- return P, np.ones_like(a1), (a1 if overwrite_a else a1.copy())
- # Then check overwrite permission
- if not _datacopied(a1, a): # "a" still alive through "a1"
- if not overwrite_a:
- # Data belongs to "a" so make a copy
- a1 = a1.copy(order='C')
- # else: Do nothing we'll use "a" if possible
- # else: a1 has its own data thus free to scratch
- # Then layout checks, might happen that overwrite is allowed but original
- # array was read-only or non-contiguous.
- if not (a1.flags['C_CONTIGUOUS'] and a1.flags['WRITEABLE']):
- a1 = a1.copy(order='C')
- if not nd: # 2D array
- p = np.empty(m, dtype=np.int32)
- u = np.zeros([k, k], dtype=a1.dtype)
- lu_dispatcher(a1, u, p, permute_l)
- P, L, U = (p, a1, u) if m > n else (p, u, a1)
- else: # Stacked array
- # Prepare the contiguous data holders
- P = np.empty([*nd, m], dtype=np.int32) # perm vecs
- if m > n: # Tall arrays, U will be created
- U = np.zeros([*nd, k, k], dtype=a1.dtype)
- for ind in product(*[range(x) for x in a1.shape[:-2]]):
- lu_dispatcher(a1[ind], U[ind], P[ind], permute_l)
- L = a1
- else: # Fat arrays, L will be created
- L = np.zeros([*nd, k, k], dtype=a1.dtype)
- for ind in product(*[range(x) for x in a1.shape[:-2]]):
- lu_dispatcher(a1[ind], L[ind], P[ind], permute_l)
- U = a1
- # Convert permutation vecs to permutation arrays
- # permute_l=False needed to enter here to avoid wasted efforts
- if (not p_indices) and (not permute_l):
- if nd:
- Pa = np.zeros([*nd, m, m], dtype=real_dchar)
- # An unreadable index hack - One-hot encoding for perm matrices
- nd_ix = np.ix_(*([np.arange(x) for x in nd]+[np.arange(m)]))
- Pa[(*nd_ix, P)] = 1
- P = Pa
- else: # 2D case
- Pa = np.zeros([m, m], dtype=real_dchar)
- Pa[np.arange(m), P] = 1
- P = Pa
- return (L, U) if permute_l else (P, L, U)
|