# test_decomp_cossin.py (12 KB)

# (source-viewer line-number gutter for lines 1-314 removed)
  1. import pytest
  2. import numpy as np
  3. from numpy.random import default_rng
  4. from numpy.testing import assert_allclose
  5. from scipy import linalg
  6. from scipy.linalg.lapack import _compute_lwork
  7. from scipy.stats import ortho_group, unitary_group
  8. from scipy.linalg import cossin, get_lapack_funcs
  9. REAL_DTYPES = (np.float32, np.float64)
  10. COMPLEX_DTYPES = (np.complex64, np.complex128)
  11. DTYPES = REAL_DTYPES + COMPLEX_DTYPES
  12. @pytest.mark.parametrize('dtype_', DTYPES)
  13. @pytest.mark.parametrize('m, p, q',
  14. [
  15. (2, 1, 1),
  16. (3, 2, 1),
  17. (3, 1, 2),
  18. (4, 2, 2),
  19. (4, 1, 2),
  20. (40, 12, 20),
  21. (40, 30, 1),
  22. (40, 1, 30),
  23. (100, 50, 1),
  24. (100, 50, 50),
  25. ])
  26. @pytest.mark.parametrize('swap_sign', [True, False])
  27. def test_cossin(dtype_, m, p, q, swap_sign):
  28. rng = default_rng(1708093570726217)
  29. if dtype_ in COMPLEX_DTYPES:
  30. x = np.array(unitary_group.rvs(m, random_state=rng), dtype=dtype_)
  31. else:
  32. x = np.array(ortho_group.rvs(m, random_state=rng), dtype=dtype_)
  33. u, cs, vh = cossin(x, p, q,
  34. swap_sign=swap_sign)
  35. assert_allclose(x, u @ cs @ vh, rtol=0., atol=m*1e3*np.finfo(dtype_).eps)
  36. assert u.dtype == dtype_
  37. # Test for float32 or float 64
  38. assert cs.dtype == np.real(u).dtype
  39. assert vh.dtype == dtype_
  40. u, cs, vh = cossin([x[:p, :q], x[:p, q:], x[p:, :q], x[p:, q:]],
  41. swap_sign=swap_sign)
  42. assert_allclose(x, u @ cs @ vh, rtol=0., atol=m*1e3*np.finfo(dtype_).eps)
  43. assert u.dtype == dtype_
  44. assert cs.dtype == np.real(u).dtype
  45. assert vh.dtype == dtype_
  46. _, cs2, vh2 = cossin(x, p, q,
  47. compute_u=False,
  48. swap_sign=swap_sign)
  49. assert_allclose(cs, cs2, rtol=0., atol=10*np.finfo(dtype_).eps)
  50. assert_allclose(vh, vh2, rtol=0., atol=10*np.finfo(dtype_).eps)
  51. u2, cs2, _ = cossin(x, p, q,
  52. compute_vh=False,
  53. swap_sign=swap_sign)
  54. assert_allclose(u, u2, rtol=0., atol=10*np.finfo(dtype_).eps)
  55. assert_allclose(cs, cs2, rtol=0., atol=10*np.finfo(dtype_).eps)
  56. _, cs2, _ = cossin(x, p, q,
  57. compute_u=False,
  58. compute_vh=False,
  59. swap_sign=swap_sign)
  60. assert_allclose(cs, cs2, rtol=0., atol=10*np.finfo(dtype_).eps)
  61. def test_cossin_mixed_types():
  62. rng = default_rng(1708093736390459)
  63. x = np.array(ortho_group.rvs(4, random_state=rng), dtype=np.float64)
  64. u, cs, vh = cossin([x[:2, :2],
  65. np.array(x[:2, 2:], dtype=np.complex128),
  66. x[2:, :2],
  67. x[2:, 2:]])
  68. assert u.dtype == np.complex128
  69. assert cs.dtype == np.float64
  70. assert vh.dtype == np.complex128
  71. assert_allclose(x, u @ cs @ vh, rtol=0.,
  72. atol=1e4 * np.finfo(np.complex128).eps)
  73. def test_cossin_error_incorrect_subblocks():
  74. with pytest.raises(ValueError, match="be due to missing p, q arguments."):
  75. cossin(([1, 2], [3, 4, 5], [6, 7], [8, 9, 10]))
  76. def test_cossin_error_empty_subblocks():
  77. with pytest.raises(ValueError, match="x11.*empty"):
  78. cossin(([], [], [], []))
  79. with pytest.raises(ValueError, match="x12.*empty"):
  80. cossin(([1, 2], [], [6, 7], [8, 9, 10]))
  81. with pytest.raises(ValueError, match="x21.*empty"):
  82. cossin(([1, 2], [3, 4, 5], [], [8, 9, 10]))
  83. with pytest.raises(ValueError, match="x22.*empty"):
  84. cossin(([1, 2], [3, 4, 5], [2], []))
  85. def test_cossin_error_missing_partitioning():
  86. with pytest.raises(ValueError, match=".*exactly four arrays.* got 2"):
  87. cossin(unitary_group.rvs(2))
  88. with pytest.raises(ValueError, match=".*might be due to missing p, q"):
  89. cossin(unitary_group.rvs(4))
  90. def test_cossin_error_non_iterable():
  91. with pytest.raises(ValueError, match="containing the subblocks of X"):
  92. cossin(12j)
  93. def test_cossin_error_invalid_shape():
  94. # Invalid x12 dimensions
  95. p, q = 3, 4
  96. invalid_x12 = np.ones((p, q + 2))
  97. valid_ones = np.ones((p, q))
  98. with pytest.raises(ValueError,
  99. match=r"Invalid x12 dimensions: desired \(3, 4\), got \(3, 6\)"):
  100. cossin((valid_ones, invalid_x12, valid_ones, valid_ones))
  101. # Invalid x21 dimensions
  102. invalid_x21 = np.ones(p + 2)
  103. with pytest.raises(ValueError,
  104. match=r"Invalid x21 dimensions: desired \(3, 4\), got \(1, 5\)"):
  105. cossin((valid_ones, valid_ones, invalid_x21, valid_ones))
  106. def test_cossin_error_non_square():
  107. with pytest.raises(ValueError, match="only supports square"):
  108. cossin(np.array([[1, 2]]), 1, 1)
  109. def test_cossin_error_partitioning():
  110. x = np.array(ortho_group.rvs(4), dtype=np.float64)
  111. with pytest.raises(ValueError, match="invalid p=0.*0<p<4.*"):
  112. cossin(x, 0, 1)
  113. with pytest.raises(ValueError, match="invalid p=4.*0<p<4.*"):
  114. cossin(x, 4, 1)
  115. with pytest.raises(ValueError, match="invalid q=-2.*0<q<4.*"):
  116. cossin(x, 1, -2)
  117. with pytest.raises(ValueError, match="invalid q=5.*0<q<4.*"):
  118. cossin(x, 1, 5)
  119. @pytest.mark.parametrize("dtype_", DTYPES)
  120. def test_cossin_separate(dtype_):
  121. rng = default_rng(1708093590167096)
  122. m, p, q = 98, 37, 61
  123. pfx = 'or' if dtype_ in REAL_DTYPES else 'un'
  124. X = (ortho_group.rvs(m, random_state=rng) if pfx == 'or'
  125. else unitary_group.rvs(m, random_state=rng))
  126. X = np.array(X, dtype=dtype_)
  127. drv, dlw = get_lapack_funcs((pfx + 'csd', pfx + 'csd_lwork'), [X])
  128. lwval = _compute_lwork(dlw, m, p, q)
  129. lwvals = {'lwork': lwval} if pfx == 'or' else dict(zip(['lwork',
  130. 'lrwork'],
  131. lwval))
  132. *_, theta, u1, u2, v1t, v2t, _ = \
  133. drv(X[:p, :q], X[:p, q:], X[p:, :q], X[p:, q:], **lwvals)
  134. (u1_2, u2_2), theta2, (v1t_2, v2t_2) = cossin(X, p, q, separate=True)
  135. assert_allclose(u1_2, u1, rtol=0., atol=10*np.finfo(dtype_).eps)
  136. assert_allclose(u2_2, u2, rtol=0., atol=10*np.finfo(dtype_).eps)
  137. assert_allclose(v1t_2, v1t, rtol=0., atol=10*np.finfo(dtype_).eps)
  138. assert_allclose(v2t_2, v2t, rtol=0., atol=10*np.finfo(dtype_).eps)
  139. assert_allclose(theta2, theta, rtol=0., atol=10*np.finfo(dtype_).eps)
  140. @pytest.mark.parametrize("m", [2, 5, 10, 15, 20])
  141. @pytest.mark.parametrize("p", [1, 4, 9, 14, 19])
  142. @pytest.mark.parametrize("q", [1, 4, 9, 14, 19])
  143. @pytest.mark.parametrize("swap_sign", [True, False])
  144. def test_properties(m, p, q, swap_sign):
  145. # Test all the properties advertised in `linalg.cossin` documentation.
  146. # There may be some overlap with tests above, but this is sensitive to
  147. # the bug reported in gh-19365 and more.
  148. if (p >= m) or (q >= m):
  149. pytest.skip("`0 < p < m` and `0 < q < m` must hold")
  150. # Generate unitary input
  151. rng = np.random.default_rng(329548272348596421)
  152. X = unitary_group.rvs(m, random_state=rng)
  153. np.testing.assert_allclose(X @ X.conj().T, np.eye(m), atol=1e-15)
  154. # Perform the decomposition
  155. u0, cs0, vh0 = linalg.cossin(X, p=p, q=q, separate=True, swap_sign=swap_sign)
  156. u1, u2 = u0
  157. v1, v2 = vh0
  158. v1, v2 = v1.conj().T, v2.conj().T
  159. # "U1, U2, V1, V2 are square orthogonal/unitary matrices
  160. # of dimensions (p,p), (m-p,m-p), (q,q), and (m-q,m-q) respectively"
  161. np.testing.assert_allclose(u1 @ u1.conj().T, np.eye(p), atol=1e-13)
  162. np.testing.assert_allclose(u2 @ u2.conj().T, np.eye(m-p), atol=1e-13)
  163. np.testing.assert_allclose(v1 @ v1.conj().T, np.eye(q), atol=1e-13)
  164. np.testing.assert_allclose(v2 @ v2.conj().T, np.eye(m-q), atol=1e-13)
  165. # "and C and S are (r, r) nonnegative diagonal matrices..."
  166. C = np.diag(np.cos(cs0))
  167. S = np.diag(np.sin(cs0))
  168. # "...satisfying C^2 + S^2 = I where r = min(p, m-p, q, m-q)."
  169. r = min(p, m-p, q, m-q)
  170. np.testing.assert_allclose(C**2 + S**2, np.eye(r))
  171. # "Moreover, the rank of the identity matrices are
  172. # min(p, q) - r, min(p, m - q) - r, min(m - p, q) - r,
  173. # and min(m - p, m - q) - r respectively."
  174. I11 = np.eye(min(p, q) - r)
  175. I12 = np.eye(min(p, m - q) - r)
  176. I21 = np.eye(min(m - p, q) - r)
  177. I22 = np.eye(min(m - p, m - q) - r)
  178. # From:
  179. # ┌ ┐
  180. # │ I 0 0 │ 0 0 0 │
  181. # ┌ ┐ ┌ ┐│ 0 C 0 │ 0 -S 0 │┌ ┐*
  182. # │ X11 │ X12 │ │ U1 │ ││ 0 0 0 │ 0 0 -I ││ V1 │ │
  183. # │ ────┼──── │ = │────┼────││─────────┼─────────││────┼────│
  184. # │ X21 │ X22 │ │ │ U2 ││ 0 0 0 │ I 0 0 ││ │ V2 │
  185. # └ ┘ └ ┘│ 0 S 0 │ 0 C 0 │└ ┘
  186. # │ 0 0 I │ 0 0 0 │
  187. # └ ┘
  188. # We can see that U and V are block diagonal matrices like so:
  189. U = linalg.block_diag(u1, u2)
  190. V = linalg.block_diag(v1, v2)
  191. # And the center matrix, which we'll call Q here, must be:
  192. Q11 = np.zeros((u1.shape[1], v1.shape[0]))
  193. IC11 = linalg.block_diag(I11, C)
  194. Q11[:IC11.shape[0], :IC11.shape[1]] = IC11
  195. Q12 = np.zeros((u1.shape[1], v2.shape[0]))
  196. SI12 = linalg.block_diag(S, I12) if swap_sign else linalg.block_diag(-S, -I12)
  197. Q12[-SI12.shape[0]:, -SI12.shape[1]:] = SI12
  198. Q21 = np.zeros((u2.shape[1], v1.shape[0]))
  199. SI21 = linalg.block_diag(-S, -I21) if swap_sign else linalg.block_diag(S, I21)
  200. Q21[-SI21.shape[0]:, -SI21.shape[1]:] = SI21
  201. Q22 = np.zeros((u2.shape[1], v2.shape[0]))
  202. IC22 = linalg.block_diag(I22, C)
  203. Q22[:IC22.shape[0], :IC22.shape[1]] = IC22
  204. Q = np.block([[Q11, Q12], [Q21, Q22]])
  205. # Confirm that `cossin` decomposes `X` as shown
  206. np.testing.assert_allclose(X, U @ Q @ V.conj().T)
  207. # And check that `separate=False` agrees
  208. U0, CS0, Vh0 = linalg.cossin(X, p=p, q=q, swap_sign=swap_sign)
  209. np.testing.assert_allclose(U, U0)
  210. np.testing.assert_allclose(Q, CS0)
  211. np.testing.assert_allclose(V, Vh0.conj().T)
  212. # Confirm that `compute_u`/`compute_vh` don't affect the results
  213. kwargs = dict(p=p, q=q, swap_sign=swap_sign)
  214. # `compute_u=False`
  215. u, cs, vh = linalg.cossin(X, separate=True, compute_u=False, **kwargs)
  216. assert u[0].shape == (0, 0) # probably not ideal, but this is what it does
  217. assert u[1].shape == (0, 0)
  218. assert_allclose(cs, cs0, rtol=1e-15)
  219. assert_allclose(vh[0], vh0[0], rtol=1e-15)
  220. assert_allclose(vh[1], vh0[1], rtol=1e-15)
  221. U, CS, Vh = linalg.cossin(X, compute_u=False, **kwargs)
  222. assert U.shape == (0, 0)
  223. assert_allclose(CS, CS0, rtol=1e-15)
  224. assert_allclose(Vh, Vh0, rtol=1e-15)
  225. # `compute_vh=False`
  226. u, cs, vh = linalg.cossin(X, separate=True, compute_vh=False, **kwargs)
  227. assert_allclose(u[0], u[0], rtol=1e-15)
  228. assert_allclose(u[1], u[1], rtol=1e-15)
  229. assert_allclose(cs, cs0, rtol=1e-15)
  230. assert vh[0].shape == (0, 0)
  231. assert vh[1].shape == (0, 0)
  232. U, CS, Vh = linalg.cossin(X, compute_vh=False, **kwargs)
  233. assert_allclose(U, U0, rtol=1e-15)
  234. assert_allclose(CS, CS0, rtol=1e-15)
  235. assert Vh.shape == (0, 0)
  236. # `compute_u=False, compute_vh=False`
  237. u, cs, vh = linalg.cossin(X, separate=True, compute_u=False,
  238. compute_vh=False, **kwargs)
  239. assert u[0].shape == (0, 0)
  240. assert u[1].shape == (0, 0)
  241. assert_allclose(cs, cs0, rtol=1e-15)
  242. assert vh[0].shape == (0, 0)
  243. assert vh[1].shape == (0, 0)
  244. U, CS, Vh = linalg.cossin(X, compute_u=False, compute_vh=False, **kwargs)
  245. assert U.shape == (0, 0)
  246. assert_allclose(CS, CS0, rtol=1e-15)
  247. assert Vh.shape == (0, 0)
  248. def test_indexing_bug_gh19365():
  249. # Regression test for gh-19365, which reported a bug with `separate=False`
  250. rng = np.random.default_rng(32954827234421)
  251. m = rng.integers(50, high=100)
  252. p = rng.integers(10, 40) # always p < m
  253. q = rng.integers(m - p + 1, m - 1) # always m-p < q < m
  254. X = unitary_group.rvs(m, random_state=rng) # random unitary matrix
  255. U, D, Vt = linalg.cossin(X, p=p, q=q, separate=False)
  256. assert np.allclose(U @ D @ Vt, X)