_decomp_schur.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. """Schur decomposition functions."""
  2. import numpy as np
  3. from numpy import asarray_chkfinite, single, asarray, array
  4. from numpy.linalg import norm
  5. from scipy._lib._util import _apply_over_batch
  6. # Local imports.
  7. from ._misc import LinAlgError, _datacopied
  8. from .lapack import get_lapack_funcs
  9. from ._decomp import eigvals
  10. __all__ = ['schur', 'rsf2csf']
  11. _double_precision = ['i', 'l', 'd']
  12. @_apply_over_batch(('a', 2))
  13. def schur(a, output='real', lwork=None, overwrite_a=False, sort=None,
  14. check_finite=True):
  15. """
  16. Compute Schur decomposition of a matrix.
  17. The Schur decomposition is::
  18. A = Z T Z^H
  19. where Z is unitary and T is either upper-triangular, or for real
  20. Schur decomposition (output='real'), quasi-upper triangular. In
  21. the quasi-triangular form, 2x2 blocks describing complex-valued
  22. eigenvalue pairs may extrude from the diagonal.
  23. Parameters
  24. ----------
  25. a : (M, M) array_like
  26. Matrix to decompose
  27. output : {'real', 'complex'}, optional
  28. When the dtype of `a` is real, this specifies whether to compute
  29. the real or complex Schur decomposition.
  30. When the dtype of `a` is complex, this argument is ignored, and the
  31. complex Schur decomposition is computed.
  32. lwork : int, optional
  33. Work array size. If None or -1, it is automatically computed.
  34. overwrite_a : bool, optional
  35. Whether to overwrite data in a (may improve performance).
  36. sort : {None, callable, 'lhp', 'rhp', 'iuc', 'ouc'}, optional
  37. Specifies whether the upper eigenvalues should be sorted. A callable
  38. may be passed that, given an eigenvalue, returns a boolean denoting
  39. whether the eigenvalue should be sorted to the top-left (True).
  40. - If ``output='complex'`` OR the dtype of `a` is complex, the callable
  41. should have one argument: the eigenvalue expressed as a complex number.
  42. - If ``output='real'`` AND the dtype of `a` is real, the callable should have
  43. two arguments: the real and imaginary parts of the eigenvalue, respectively.
  44. Alternatively, string parameters may be used::
  45. 'lhp' Left-hand plane (real(eigenvalue) < 0.0)
  46. 'rhp' Right-hand plane (real(eigenvalue) >= 0.0)
  47. 'iuc' Inside the unit circle (abs(eigenvalue) <= 1.0)
  48. 'ouc' Outside the unit circle (abs(eigenvalue) > 1.0)
  49. Defaults to None (no sorting).
  50. check_finite : bool, optional
  51. Whether to check that the input matrix contains only finite numbers.
  52. Disabling may give a performance gain, but may result in problems
  53. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  54. Returns
  55. -------
  56. T : (M, M) ndarray
  57. Schur form of A. It is real-valued for the real Schur decomposition.
  58. Z : (M, M) ndarray
  59. A unitary Schur transformation matrix for A.
  60. It is real-valued for the real Schur decomposition.
  61. sdim : int
  62. If and only if sorting was requested, a third return value will
  63. contain the number of eigenvalues satisfying the sort condition.
  64. Note that complex conjugate pairs for which the condition is true
  65. for either eigenvalue count as 2.
  66. Raises
  67. ------
  68. LinAlgError
  69. Error raised under three conditions:
  70. 1. The algorithm failed due to a failure of the QR algorithm to
  71. compute all eigenvalues.
  72. 2. If eigenvalue sorting was requested, the eigenvalues could not be
  73. reordered due to a failure to separate eigenvalues, usually because
  74. of poor conditioning.
  75. 3. If eigenvalue sorting was requested, roundoff errors caused the
  76. leading eigenvalues to no longer satisfy the sorting condition.
  77. See Also
  78. --------
  79. rsf2csf : Convert real Schur form to complex Schur form
  80. Examples
  81. --------
  82. >>> import numpy as np
  83. >>> from scipy.linalg import schur, eigvals
  84. >>> A = np.array([[0, 2, 2], [0, 1, 2], [1, 0, 1]])
  85. >>> T, Z = schur(A)
  86. >>> T
  87. array([[ 2.65896708, 1.42440458, -1.92933439],
  88. [ 0. , -0.32948354, -0.49063704],
  89. [ 0. , 1.31178921, -0.32948354]])
  90. >>> Z
  91. array([[0.72711591, -0.60156188, 0.33079564],
  92. [0.52839428, 0.79801892, 0.28976765],
  93. [0.43829436, 0.03590414, -0.89811411]])
  94. >>> T2, Z2 = schur(A, output='complex')
  95. >>> T2
  96. array([[ 2.65896708, -1.22839825+1.32378589j, 0.42590089+1.51937378j], # may vary
  97. [ 0. , -0.32948354+0.80225456j, -0.59877807+0.56192146j],
  98. [ 0. , 0. , -0.32948354-0.80225456j]])
  99. >>> eigvals(T2)
  100. array([2.65896708, -0.32948354+0.80225456j, -0.32948354-0.80225456j]) # may vary
  101. A custom eigenvalue-sorting condition that sorts by positive imaginary part
  102. is satisfied by only one eigenvalue.
  103. >>> _, _, sdim = schur(A, output='complex', sort=lambda x: x.imag > 1e-15)
  104. >>> sdim
  105. 1
  106. When ``output='real'`` and the array `a` is real, the `sort` callable must accept
  107. the real and imaginary parts as separate arguments. Note that now the complex
  108. eigenvalues ``-0.32948354+0.80225456j`` and ``-0.32948354-0.80225456j`` will be
  109. treated as a complex conjugate pair, and according to the `sdim` documentation,
  110. complex conjugate pairs for which the condition is True for *either* eigenvalue
  111. increase `sdim` by *two*.
  112. >>> _, _, sdim = schur(A, output='real', sort=lambda x, y: y > 1e-15)
  113. >>> sdim
  114. 2
  115. """
  116. if output not in ['real', 'complex', 'r', 'c']:
  117. raise ValueError("argument must be 'real', or 'complex'")
  118. if check_finite:
  119. a1 = asarray_chkfinite(a)
  120. else:
  121. a1 = asarray(a)
  122. if np.issubdtype(a1.dtype, np.integer):
  123. a1 = asarray(a, dtype=np.dtype("long"))
  124. if len(a1.shape) != 2 or (a1.shape[0] != a1.shape[1]):
  125. raise ValueError('expected square matrix')
  126. typ = a1.dtype.char
  127. if output in ['complex', 'c'] and typ not in ['F', 'D']:
  128. if typ in _double_precision:
  129. a1 = a1.astype('D')
  130. else:
  131. a1 = a1.astype('F')
  132. # accommodate empty matrix
  133. if a1.size == 0:
  134. t0, z0 = schur(np.eye(2, dtype=a1.dtype))
  135. if sort is None:
  136. return (np.empty_like(a1, dtype=t0.dtype),
  137. np.empty_like(a1, dtype=z0.dtype))
  138. else:
  139. return (np.empty_like(a1, dtype=t0.dtype),
  140. np.empty_like(a1, dtype=z0.dtype), 0)
  141. overwrite_a = overwrite_a or (_datacopied(a1, a))
  142. gees, = get_lapack_funcs(('gees',), (a1,))
  143. if lwork is None or lwork == -1:
  144. # get optimal work array
  145. result = gees(lambda x: None, a1, lwork=-1)
  146. lwork = result[-2][0].real.astype(np.int_)
  147. if sort is None:
  148. sort_t = 0
  149. def sfunction(x, y=None):
  150. return None
  151. else:
  152. sort_t = 1
  153. if callable(sort):
  154. sfunction = sort
  155. elif sort == 'lhp':
  156. def sfunction(x, y=None):
  157. return x.real < 0.0
  158. elif sort == 'rhp':
  159. def sfunction(x, y=None):
  160. return x.real >= 0.0
  161. elif sort == 'iuc':
  162. def sfunction(x, y=None):
  163. z = x if y is None else x + y*1j
  164. return abs(z) <= 1.0
  165. elif sort == 'ouc':
  166. def sfunction(x, y=None):
  167. z = x if y is None else x + y*1j
  168. return abs(z) > 1.0
  169. else:
  170. raise ValueError("'sort' parameter must either be 'None', or a "
  171. "callable, or one of ('lhp','rhp','iuc','ouc')")
  172. result = gees(sfunction, a1, lwork=lwork, overwrite_a=overwrite_a,
  173. sort_t=sort_t)
  174. info = result[-1]
  175. if info < 0:
  176. raise ValueError(f'illegal value in {-info}-th argument of internal gees')
  177. elif info == a1.shape[0] + 1:
  178. raise LinAlgError('Eigenvalues could not be separated for reordering.')
  179. elif info == a1.shape[0] + 2:
  180. raise LinAlgError('Leading eigenvalues do not satisfy sort condition.')
  181. elif info > 0:
  182. raise LinAlgError("Schur form not found. Possibly ill-conditioned.")
  183. if sort is None:
  184. return result[0], result[-3]
  185. else:
  186. return result[0], result[-3], result[1]
  187. eps = np.finfo(float).eps
  188. feps = np.finfo(single).eps
  189. _array_kind = {'b': 0, 'h': 0, 'B': 0, 'i': 0, 'l': 0,
  190. 'f': 0, 'd': 0, 'F': 1, 'D': 1}
  191. _array_precision = {'i': 1, 'l': 1, 'f': 0, 'd': 1, 'F': 0, 'D': 1}
  192. _array_type = [['f', 'd'], ['F', 'D']]
  193. def _commonType(*arrays):
  194. kind = 0
  195. precision = 0
  196. for a in arrays:
  197. t = a.dtype.char
  198. kind = max(kind, _array_kind[t])
  199. precision = max(precision, _array_precision[t])
  200. return _array_type[kind][precision]
  201. def _castCopy(type, *arrays):
  202. cast_arrays = ()
  203. for a in arrays:
  204. if a.dtype.char == type:
  205. cast_arrays = cast_arrays + (a.copy(),)
  206. else:
  207. cast_arrays = cast_arrays + (a.astype(type),)
  208. if len(cast_arrays) == 1:
  209. return cast_arrays[0]
  210. else:
  211. return cast_arrays
  212. @_apply_over_batch(('T', 2), ('Z', 2))
  213. def rsf2csf(T, Z, check_finite=True):
  214. """
  215. Convert real Schur form to complex Schur form.
  216. Convert a quasi-diagonal real-valued Schur form to the upper-triangular
  217. complex-valued Schur form.
  218. Parameters
  219. ----------
  220. T : (M, M) array_like
  221. Real Schur form of the original array
  222. Z : (M, M) array_like
  223. Schur transformation matrix
  224. check_finite : bool, optional
  225. Whether to check that the input arrays contain only finite numbers.
  226. Disabling may give a performance gain, but may result in problems
  227. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  228. Returns
  229. -------
  230. T : (M, M) ndarray
  231. Complex Schur form of the original array
  232. Z : (M, M) ndarray
  233. Schur transformation matrix corresponding to the complex form
  234. See Also
  235. --------
  236. schur : Schur decomposition of an array
  237. Examples
  238. --------
  239. >>> import numpy as np
  240. >>> from scipy.linalg import schur, rsf2csf
  241. >>> A = np.array([[0, 2, 2], [0, 1, 2], [1, 0, 1]])
  242. >>> T, Z = schur(A)
  243. >>> T
  244. array([[ 2.65896708, 1.42440458, -1.92933439],
  245. [ 0. , -0.32948354, -0.49063704],
  246. [ 0. , 1.31178921, -0.32948354]])
  247. >>> Z
  248. array([[0.72711591, -0.60156188, 0.33079564],
  249. [0.52839428, 0.79801892, 0.28976765],
  250. [0.43829436, 0.03590414, -0.89811411]])
  251. >>> T2 , Z2 = rsf2csf(T, Z)
  252. >>> T2
  253. array([[2.65896708+0.j, -1.64592781+0.743164187j, -1.21516887+1.00660462j],
  254. [0.+0.j , -0.32948354+8.02254558e-01j, -0.82115218-2.77555756e-17j],
  255. [0.+0.j , 0.+0.j, -0.32948354-0.802254558j]])
  256. >>> Z2
  257. array([[0.72711591+0.j, 0.28220393-0.31385693j, 0.51319638-0.17258824j],
  258. [0.52839428+0.j, 0.24720268+0.41635578j, -0.68079517-0.15118243j],
  259. [0.43829436+0.j, -0.76618703+0.01873251j, -0.03063006+0.46857912j]])
  260. """
  261. if check_finite:
  262. Z, T = map(asarray_chkfinite, (Z, T))
  263. else:
  264. Z, T = map(asarray, (Z, T))
  265. for ind, X in enumerate([Z, T]):
  266. if X.ndim != 2 or X.shape[0] != X.shape[1]:
  267. raise ValueError(f"Input '{'ZT'[ind]}' must be square.")
  268. if T.shape[0] != Z.shape[0]:
  269. message = f"Input array shapes must match: Z: {Z.shape} vs. T: {T.shape}"
  270. raise ValueError(message)
  271. N = T.shape[0]
  272. t = _commonType(Z, T, array([3.0], 'F'))
  273. Z, T = _castCopy(t, Z, T)
  274. for m in range(N-1, 0, -1):
  275. if abs(T[m, m-1]) > eps*(abs(T[m-1, m-1]) + abs(T[m, m])):
  276. mu = eigvals(T[m-1:m+1, m-1:m+1]) - T[m, m]
  277. r = norm([mu[0], T[m, m-1]])
  278. c = mu[0] / r
  279. s = T[m, m-1] / r
  280. G = array([[c.conj(), s], [-s, c]], dtype=t)
  281. T[m-1:m+1, m-1:] = G.dot(T[m-1:m+1, m-1:])
  282. T[:m+1, m-1:m+1] = T[:m+1, m-1:m+1].dot(G.conj().T)
  283. Z[:, m-1:m+1] = Z[:, m-1:m+1].dot(G.conj().T)
  284. T[m, m-1] = 0.0
  285. return T, Z