| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107 |
- """
- Matrix square root for general matrices and for upper triangular matrices.
- This module exists to avoid cyclic imports.
- """
- __all__ = []
- import numpy as np
- # Local imports
- from .lapack import ztrsyl, dtrsyl
class SqrtmError(np.linalg.LinAlgError):
    """Linear-algebra error raised by the matrix square root routines.

    Derives from ``numpy.linalg.LinAlgError`` so callers that already
    handle generic linear-algebra failures also catch this one.
    """
- from ._matfuncs_sqrtm_triu import within_block_loop # noqa: E402
def _sqrtm_triu(T, blocksize=64):
    """
    Matrix square root of an upper triangular matrix.

    This is a helper function for `sqrtm` and `logm`.

    Parameters
    ----------
    T : (N, N) array_like upper triangular
        Matrix whose square root to evaluate
    blocksize : int, optional
        If the blocksize is not degenerate with respect to the
        size of the input array, then use a blocked algorithm. (Default: 64)

    Returns
    -------
    sqrtm : (N, N) ndarray
        Value of the sqrt function at `T`. The result is ``float64`` when
        `T` is real with a non-negative diagonal and ``complex128``
        otherwise.

    Raises
    ------
    ValueError
        If `blocksize` is smaller than 1, or if `T` is not a square
        two-dimensional array.
    SqrtmError
        If the within-block kernel raises a ``RuntimeError``.

    References
    ----------
    .. [1] Edvin Deadman, Nicholas J. Higham, Rui Ralha (2013)
           "Blocked Schur Algorithms for Computing the Matrix Square Root",
           Lecture Notes in Computer Science, 7782. pp. 171-182.

    """
    # Reject degenerate block sizes up front: zero would otherwise surface as
    # an opaque ZeroDivisionError and negative values would be silently
    # treated as "one block".
    if blocksize < 1:
        raise ValueError("The blocksize should be at least 1.")

    # Validate the shape before any computation so that mismatched
    # dimensions never reach the compiled kernel.
    T = np.asarray(T)
    if T.ndim != 2 or T.shape[0] != T.shape[1]:
        raise ValueError("Expected a square two-dimensional array, "
                         f"got an array of shape {T.shape}.")

    T_diag = np.diag(T)
    # Stay in real arithmetic only if every diagonal square root is real.
    # ``initial=0.`` keeps ``np.min`` well defined for an empty diagonal.
    keep_it_real = np.isrealobj(T) and np.min(T_diag, initial=0.) >= 0

    # Cast to complex as necessary + ensure double precision
    work_dtype = np.float64 if keep_it_real else np.complex128
    T = np.asarray(T, dtype=work_dtype, order="C")
    T_diag = np.asarray(T_diag, dtype=work_dtype)

    # The diagonal of the result is the elementwise square root of the
    # diagonal of T; the remaining entries are filled in below.
    R = np.diag(np.sqrt(T_diag))

    # Compute the number of blocks to use; use at least one block.
    n = T.shape[0]
    nblocks = max(n // blocksize, 1)

    # Compute the smaller of the two sizes of blocks that
    # we will actually use, and compute the number of large blocks.
    bsmall, nlarge = divmod(n, nblocks)
    blarge = bsmall + 1
    nsmall = nblocks - nlarge
    if nsmall * bsmall + nlarge * blarge != n:
        raise RuntimeError('internal inconsistency')

    # Define the index range covered by each block: all small blocks first,
    # followed by the large ones.
    start_stop_pairs = []
    start = 0
    for count, size in ((nsmall, bsmall), (nlarge, blarge)):
        for _ in range(count):
            start_stop_pairs.append((start, start + size))
            start += size

    # Within-block interactions (Cythonized)
    try:
        within_block_loop(R, T, start_stop_pairs, nblocks)
    except RuntimeError as e:
        raise SqrtmError(*e.args) from e

    # Between-block interactions (Cython would give no significant speedup)
    trsyl = dtrsyl if keep_it_real else ztrsyl
    for j in range(nblocks):
        jstart, jstop = start_stop_pairs[j]
        # Walk up the block column from the block just above the diagonal.
        for i in range(j - 1, -1, -1):
            istart, istop = start_stop_pairs[i]
            S = T[istart:istop, jstart:jstop]
            if j - i > 1:
                # Subtract the contribution of the blocks lying strictly
                # between block row i and block column j.
                S = S - (R[istart:istop, istop:jstart]
                         @ R[istop:jstart, jstart:jstop])

            # Invoke LAPACK.
            # For more details, see the solve_sylvester implementation
            # and the fortran dtrsyl and ztrsyl docs.
            Rii = R[istart:istop, istart:istop]
            Rjj = R[jstart:jstop, jstart:jstop]
            # NOTE(review): ``info`` is not inspected here; confirm against
            # the LAPACK docs whether a nonzero value should be reported.
            x, scale, info = trsyl(Rii, Rjj, S)
            R[istart:istop, jstart:jstop] = x * scale

    # Return the matrix square root.
    return R
|