_matfuncs_sqrtm.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
"""
Matrix square root for general matrices and for upper triangular matrices.

This module exists to avoid cyclic imports.
"""
# Intentionally empty: ``from <this module> import *`` exports no names.
__all__ = []
  6. import numpy as np
  7. # Local imports
  8. from .lapack import ztrsyl, dtrsyl
  9. class SqrtmError(np.linalg.LinAlgError):
  10. pass
  11. from ._matfuncs_sqrtm_triu import within_block_loop # noqa: E402
  12. def _sqrtm_triu(T, blocksize=64):
  13. """
  14. Matrix square root of an upper triangular matrix.
  15. This is a helper function for `sqrtm` and `logm`.
  16. Parameters
  17. ----------
  18. T : (N, N) array_like upper triangular
  19. Matrix whose square root to evaluate
  20. blocksize : int, optional
  21. If the blocksize is not degenerate with respect to the
  22. size of the input array, then use a blocked algorithm. (Default: 64)
  23. Returns
  24. -------
  25. sqrtm : (N, N) ndarray
  26. Value of the sqrt function at `T`
  27. References
  28. ----------
  29. .. [1] Edvin Deadman, Nicholas J. Higham, Rui Ralha (2013)
  30. "Blocked Schur Algorithms for Computing the Matrix Square Root,
  31. Lecture Notes in Computer Science, 7782. pp. 171-182.
  32. """
  33. T_diag = np.diag(T)
  34. keep_it_real = np.isrealobj(T) and np.min(T_diag, initial=0.) >= 0
  35. # Cast to complex as necessary + ensure double precision
  36. if not keep_it_real:
  37. T = np.asarray(T, dtype=np.complex128, order="C")
  38. T_diag = np.asarray(T_diag, dtype=np.complex128)
  39. else:
  40. T = np.asarray(T, dtype=np.float64, order="C")
  41. T_diag = np.asarray(T_diag, dtype=np.float64)
  42. R = np.diag(np.sqrt(T_diag))
  43. # Compute the number of blocks to use; use at least one block.
  44. n, n = T.shape
  45. nblocks = max(n // blocksize, 1)
  46. # Compute the smaller of the two sizes of blocks that
  47. # we will actually use, and compute the number of large blocks.
  48. bsmall, nlarge = divmod(n, nblocks)
  49. blarge = bsmall + 1
  50. nsmall = nblocks - nlarge
  51. if nsmall * bsmall + nlarge * blarge != n:
  52. raise Exception('internal inconsistency')
  53. # Define the index range covered by each block.
  54. start_stop_pairs = []
  55. start = 0
  56. for count, size in ((nsmall, bsmall), (nlarge, blarge)):
  57. for i in range(count):
  58. start_stop_pairs.append((start, start + size))
  59. start += size
  60. # Within-block interactions (Cythonized)
  61. try:
  62. within_block_loop(R, T, start_stop_pairs, nblocks)
  63. except RuntimeError as e:
  64. raise SqrtmError(*e.args) from e
  65. # Between-block interactions (Cython would give no significant speedup)
  66. for j in range(nblocks):
  67. jstart, jstop = start_stop_pairs[j]
  68. for i in range(j-1, -1, -1):
  69. istart, istop = start_stop_pairs[i]
  70. S = T[istart:istop, jstart:jstop]
  71. if j - i > 1:
  72. S = S - R[istart:istop, istop:jstart].dot(R[istop:jstart,
  73. jstart:jstop])
  74. # Invoke LAPACK.
  75. # For more details, see the solve_sylvester implementation
  76. # and the fortran dtrsyl and ztrsyl docs.
  77. Rii = R[istart:istop, istart:istop]
  78. Rjj = R[jstart:jstop, jstart:jstop]
  79. if keep_it_real:
  80. x, scale, info = dtrsyl(Rii, Rjj, S)
  81. else:
  82. x, scale, info = ztrsyl(Rii, Rjj, S)
  83. R[istart:istop, jstart:jstop] = x * scale
  84. # Return the matrix square root.
  85. return R