_decomp_cossin.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. from collections.abc import Iterable
  2. import numpy as np
  3. from scipy._lib._util import _asarray_validated, _apply_over_batch
  4. from scipy.linalg import block_diag, LinAlgError
  5. from .lapack import _compute_lwork, get_lapack_funcs
  6. __all__ = ['cossin']
  7. def cossin(X, p=None, q=None, separate=False,
  8. swap_sign=False, compute_u=True, compute_vh=True):
  9. """
  10. Compute the cosine-sine (CS) decomposition of an orthogonal/unitary matrix.
  11. X is an ``(m, m)`` orthogonal/unitary matrix, partitioned as the following
  12. where upper left block has the shape of ``(p, q)``::
  13. ┌ ┐
  14. │ I 0 0 │ 0 0 0 │
  15. ┌ ┐ ┌ ┐│ 0 C 0 │ 0 -S 0 │┌ ┐*
  16. │ X11 │ X12 │ │ U1 │ ││ 0 0 0 │ 0 0 -I ││ V1 │ │
  17. │ ────┼──── │ = │────┼────││─────────┼─────────││────┼────│
  18. │ X21 │ X22 │ │ │ U2 ││ 0 0 0 │ I 0 0 ││ │ V2 │
  19. └ ┘ └ ┘│ 0 S 0 │ 0 C 0 │└ ┘
  20. │ 0 0 I │ 0 0 0 │
  21. └ ┘
  22. ``U1``, ``U2``, ``V1``, ``V2`` are square orthogonal/unitary matrices of
  23. dimensions ``(p,p)``, ``(m-p,m-p)``, ``(q,q)``, and ``(m-q,m-q)``
  24. respectively, and ``C`` and ``S`` are ``(r, r)`` nonnegative diagonal
  25. matrices satisfying ``C^2 + S^2 = I`` where ``r = min(p, m-p, q, m-q)``.
  26. Moreover, the rank of the identity matrices are ``min(p, q) - r``,
  27. ``min(p, m - q) - r``, ``min(m - p, q) - r``, and ``min(m - p, m - q) - r``
  28. respectively.
  29. X can be supplied either by itself and block specifications p, q or its
  30. subblocks in an iterable from which the shapes would be derived. See the
  31. examples below.
  32. The documentation is written assuming array arguments are of specified
  33. "core" shapes. However, array argument(s) of this function may have additional
  34. "batch" dimensions prepended to the core shape. In this case, the array is treated
  35. as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.
  36. Parameters
  37. ----------
  38. X : array_like, iterable
  39. complex unitary or real orthogonal matrix to be decomposed, or iterable
  40. of subblocks ``X11``, ``X12``, ``X21``, ``X22``, when ``p``, ``q`` are
  41. omitted.
  42. p : int, optional
  43. Number of rows of the upper left block ``X11``, used only when X is
  44. given as an array.
  45. q : int, optional
  46. Number of columns of the upper left block ``X11``, used only when X is
  47. given as an array.
  48. separate : bool, optional
  49. if ``True``, the low level components are returned instead of the
  50. matrix factors, i.e. ``(u1,u2)``, ``theta``, ``(v1h,v2h)`` instead of
  51. ``u``, ``cs``, ``vh``.
  52. swap_sign : bool, optional
  53. if ``True``, the ``-S``, ``-I`` block will be the bottom left,
  54. otherwise (by default) they will be in the upper right block.
  55. compute_u : bool, optional
  56. if ``False``, ``u`` won't be computed and an empty array is returned.
  57. compute_vh : bool, optional
  58. if ``False``, ``vh`` won't be computed and an empty array is returned.
  59. Returns
  60. -------
  61. u : ndarray
  62. When ``compute_u=True``, contains the block diagonal orthogonal/unitary
  63. matrix consisting of the blocks ``U1`` (``p`` x ``p``) and ``U2``
  64. (``m-p`` x ``m-p``) orthogonal/unitary matrices. If ``separate=True``,
  65. this contains the tuple of ``(U1, U2)``.
  66. cs : ndarray
  67. The cosine-sine factor with the structure described above.
  68. If ``separate=True``, this contains the ``theta`` array containing the
  69. angles in radians.
  70. vh : ndarray
  71. When ``compute_vh=True`, contains the block diagonal orthogonal/unitary
  72. matrix consisting of the blocks ``V1H`` (``q`` x ``q``) and ``V2H``
  73. (``m-q`` x ``m-q``) orthogonal/unitary matrices. If ``separate=True``,
  74. this contains the tuple of ``(V1H, V2H)``.
  75. References
  76. ----------
  77. .. [1] Brian D. Sutton. Computing the complete CS decomposition. Numer.
  78. Algorithms, 50(1):33-65, 2009.
  79. Examples
  80. --------
  81. >>> import numpy as np
  82. >>> from scipy.linalg import cossin
  83. >>> from scipy.stats import unitary_group
  84. >>> x = unitary_group.rvs(4)
  85. >>> u, cs, vdh = cossin(x, p=2, q=2)
  86. >>> np.allclose(x, u @ cs @ vdh)
  87. True
  88. Same can be entered via subblocks without the need of ``p`` and ``q``. Also
  89. let's skip the computation of ``u``
  90. >>> ue, cs, vdh = cossin((x[:2, :2], x[:2, 2:], x[2:, :2], x[2:, 2:]),
  91. ... compute_u=False)
  92. >>> print(ue)
  93. []
  94. >>> np.allclose(x, u @ cs @ vdh)
  95. True
  96. """
  97. if p or q:
  98. p = 1 if p is None else int(p)
  99. q = 1 if q is None else int(q)
  100. X = _asarray_validated(X, check_finite=True)
  101. if not np.equal(*X.shape[-2:]):
  102. raise ValueError("Cosine Sine decomposition only supports square"
  103. f" matrices, got {X.shape[-2:]}")
  104. m = X.shape[-2]
  105. if p >= m or p <= 0:
  106. raise ValueError(f"invalid p={p}, 0<p<{X.shape[-2]} must hold")
  107. if q >= m or q <= 0:
  108. raise ValueError(f"invalid q={q}, 0<q<{X.shape[-2]} must hold")
  109. x11, x12, x21, x22 = (X[..., :p, :q], X[..., :p, q:],
  110. X[..., p:, :q], X[..., p:, q:])
  111. elif not isinstance(X, Iterable):
  112. raise ValueError("When p and q are None, X must be an Iterable"
  113. " containing the subblocks of X")
  114. else:
  115. if len(X) != 4:
  116. raise ValueError("When p and q are None, exactly four arrays"
  117. f" should be in X, got {len(X)}")
  118. x11, x12, x21, x22 = (np.atleast_2d(x) for x in X)
  119. return _cossin(x11, x12, x21, x22, separate=separate, swap_sign=swap_sign,
  120. compute_u=compute_u, compute_vh=compute_vh)
  121. @_apply_over_batch(('x11', 2), ('x12', 2), ('x21', 2), ('x22', 2))
  122. def _cossin(x11, x12, x21, x22, separate, swap_sign, compute_u, compute_vh):
  123. for name, block in zip(["x11", "x12", "x21", "x22"],
  124. [x11, x12, x21, x22]):
  125. if block.shape[1] == 0:
  126. raise ValueError(f"{name} can't be empty")
  127. p, q = x11.shape
  128. mmp, mmq = x22.shape
  129. if x12.shape != (p, mmq):
  130. raise ValueError(f"Invalid x12 dimensions: desired {(p, mmq)}, "
  131. f"got {x12.shape}")
  132. if x21.shape != (mmp, q):
  133. raise ValueError(f"Invalid x21 dimensions: desired {(mmp, q)}, "
  134. f"got {x21.shape}")
  135. if p + mmp != q + mmq:
  136. raise ValueError("The subblocks have compatible sizes but "
  137. "don't form a square array (instead they form a"
  138. f" {p + mmp}x{q + mmq} array). This might be "
  139. "due to missing p, q arguments.")
  140. m = p + mmp
  141. cplx = any([np.iscomplexobj(x) for x in [x11, x12, x21, x22]])
  142. driver = "uncsd" if cplx else "orcsd"
  143. csd, csd_lwork = get_lapack_funcs([driver, driver + "_lwork"],
  144. [x11, x12, x21, x22])
  145. lwork = _compute_lwork(csd_lwork, m=m, p=p, q=q)
  146. lwork_args = ({'lwork': lwork[0], 'lrwork': lwork[1]} if cplx else
  147. {'lwork': lwork})
  148. *_, theta, u1, u2, v1h, v2h, info = csd(x11=x11, x12=x12, x21=x21, x22=x22,
  149. compute_u1=compute_u,
  150. compute_u2=compute_u,
  151. compute_v1t=compute_vh,
  152. compute_v2t=compute_vh,
  153. trans=False, signs=swap_sign,
  154. **lwork_args)
  155. method_name = csd.typecode + driver
  156. if info < 0:
  157. raise ValueError(f'illegal value in argument {-info} '
  158. f'of internal {method_name}')
  159. if info > 0:
  160. raise LinAlgError(f"{method_name} did not converge: {info}")
  161. if separate:
  162. return (u1, u2), theta, (v1h, v2h)
  163. U = block_diag(u1, u2)
  164. VDH = block_diag(v1h, v2h)
  165. # Construct the middle factor CS
  166. c = np.diag(np.cos(theta))
  167. s = np.diag(np.sin(theta))
  168. r = min(p, q, m - p, m - q)
  169. n11 = min(p, q) - r
  170. n12 = min(p, m - q) - r
  171. n21 = min(m - p, q) - r
  172. n22 = min(m - p, m - q) - r
  173. Id = np.eye(np.max([n11, n12, n21, n22, r]), dtype=theta.dtype)
  174. CS = np.zeros((m, m), dtype=theta.dtype)
  175. CS[:n11, :n11] = Id[:n11, :n11]
  176. xs = n11 + r
  177. xe = n11 + r + n12
  178. ys = n11 + n21 + n22 + 2 * r
  179. ye = n11 + n21 + n22 + 2 * r + n12
  180. CS[xs: xe, ys:ye] = Id[:n12, :n12] if swap_sign else -Id[:n12, :n12]
  181. xs = p + n22 + r
  182. xe = p + n22 + r + + n21
  183. ys = n11 + r
  184. ye = n11 + r + n21
  185. CS[xs:xe, ys:ye] = -Id[:n21, :n21] if swap_sign else Id[:n21, :n21]
  186. CS[p:p + n22, q:q + n22] = Id[:n22, :n22]
  187. CS[n11:n11 + r, n11:n11 + r] = c
  188. CS[p + n22:p + n22 + r, n11 + r + n21 + n22:2 * r + n11 + n21 + n22] = c
  189. xs = n11
  190. xe = n11 + r
  191. ys = n11 + n21 + n22 + r
  192. ye = n11 + n21 + n22 + 2 * r
  193. CS[xs:xe, ys:ye] = s if swap_sign else -s
  194. CS[p + n22:p + n22 + r, n11:n11 + r] = -s if swap_sign else s
  195. return U, CS, VDH