# test_propack.py (7.3 KB)

  1. import os
  2. import pytest
  3. import numpy as np
  4. from numpy.testing import assert_allclose
  5. from pytest import raises as assert_raises
  6. from scipy.sparse.linalg._svdp import _svdp
  7. from scipy.sparse import csr_array, csc_array
  8. # dtype_flavour to tolerance
  9. TOLS = {
  10. np.float32: 1e-4,
  11. np.float64: 1e-8,
  12. np.complex64: 1e-4,
  13. np.complex128: 1e-8,
  14. }
  15. def is_complex_type(dtype):
  16. return np.dtype(dtype).kind == "c"
  17. _dtypes = []
  18. for dtype_flavour in TOLS.keys():
  19. marks = []
  20. if is_complex_type(dtype_flavour):
  21. marks = [pytest.mark.slow]
  22. _dtypes.append(pytest.param(dtype_flavour, marks=marks,
  23. id=dtype_flavour.__name__))
  24. _dtypes = tuple(_dtypes) # type: ignore[assignment]
  25. # The test function here is adapted from the original Fortran PROPACK tests.
  26. # It is not very robust to arbitrary seeding since partial reorthogonalization
  27. # does not have a predictable upperbound on the number of iterations.
  28. def check_svdp(n, m, constructor, dtype, k, irl_mode, which, f=0.8, rng=None):
  29. tol = TOLS[dtype]
  30. if rng is None:
  31. rng = np.random.default_rng(0)
  32. # Legacy clamp for the generator
  33. rng2 = np.random.default_rng(0)
  34. if is_complex_type(dtype):
  35. M = (- 5 + 10 * rng2.uniform(size=[n, m])
  36. - 5j + 10j * rng2.uniform(size=[n, m])).astype(dtype)
  37. else:
  38. M = (-5 + 10 * rng2.uniform(size=[n, m])).astype(dtype)
  39. M[M.real > 10 * f - 5] = 0
  40. Msp = constructor(M)
  41. u1, sigma1, vt1 = np.linalg.svd(M, full_matrices=False)
  42. u2, sigma2, vt2, _ = _svdp(Msp, k=k,which=which, irl_mode=irl_mode,
  43. tol=tol, rng=rng)
  44. # check the which
  45. if which.upper() == 'SM':
  46. u1 = np.roll(u1, k, 1)
  47. vt1 = np.roll(vt1, k, 0)
  48. sigma1 = np.roll(sigma1, k)
  49. # check that singular values agree
  50. assert_allclose(sigma1[:k], sigma2, rtol=tol, atol=tol)
  51. # check that singular vectors are orthogonal
  52. assert_allclose(np.abs(u1.conj().T @ u2), np.eye(n, k), rtol=tol, atol=tol)
  53. assert_allclose(np.abs(vt1.conj() @ vt2.T), np.eye(n, k), rtol=tol, atol=tol)
  54. @pytest.mark.parametrize('ctor', (np.array, csr_array, csc_array))
  55. @pytest.mark.parametrize('dtype', [np.float32, np.float64,
  56. np.complex64, np.complex128])
  57. @pytest.mark.parametrize('irl', (True, False))
  58. @pytest.mark.parametrize('which', ('LM', 'SM'))
  59. def test_svdp(ctor, dtype, irl, which):
  60. rng = np.random.default_rng(1757937293955503)
  61. n, m, k = 10, 20, 3
  62. if which == 'SM' and not irl:
  63. message = "`which`='SM' requires irl_mode=True"
  64. with assert_raises(ValueError, match=message):
  65. check_svdp(n, m, ctor, dtype, k, irl, which, rng=rng)
  66. else:
  67. check_svdp(n, m, ctor, dtype, k, irl, which, rng=rng)
  68. @pytest.mark.xslow
  69. @pytest.mark.parametrize('dtype', _dtypes)
  70. @pytest.mark.parametrize('irl', (False, True))
  71. def test_examples(dtype, irl):
  72. # Note: atol for complex64 bumped from 1e-4 to 1e-3 due to test failures
  73. # with BLIS, Netlib, and MKL+AVX512 - see
  74. # https://github.com/conda-forge/scipy-feedstock/pull/198#issuecomment-999180432
  75. atol = {
  76. np.float32: 1.3e-4,
  77. np.float64: 1e-9,
  78. np.complex64: 1e-3,
  79. np.complex128: 1e-9,
  80. }[dtype]
  81. path_prefix = os.path.dirname(__file__)
  82. # Test matrices from `illc1850.coord` and `mhd1280b.cua` distributed with
  83. # PROPACK 2.1: http://sun.stanford.edu/~rmunk/PROPACK/
  84. relative_path = "propack_test_data.npz"
  85. filename = os.path.join(path_prefix, relative_path)
  86. with np.load(filename, allow_pickle=True) as data:
  87. if is_complex_type(dtype):
  88. A = data['A_complex'].item().astype(dtype)
  89. else:
  90. A = data['A_real'].item().astype(dtype)
  91. k = 200
  92. u, s, vh, _ = _svdp(A, k, irl_mode=irl, rng=np.random.default_rng(0))
  93. # complex example matrix has many repeated singular values, so check only
  94. # beginning non-repeated singular vectors to avoid permutations
  95. sv_check = 27 if is_complex_type(dtype) else k
  96. u = u[:, :sv_check]
  97. vh = vh[:sv_check, :]
  98. s = s[:sv_check]
  99. # Check orthogonality of singular vectors
  100. assert_allclose(np.eye(u.shape[1]), u.conj().T @ u, atol=atol)
  101. assert_allclose(np.eye(vh.shape[0]), vh @ vh.conj().T, atol=atol)
  102. # Ensure the norm of the difference between the np.linalg.svd and
  103. # PROPACK reconstructed matrices is small
  104. u3, s3, vh3 = np.linalg.svd(A.todense())
  105. u3 = u3[:, :sv_check]
  106. s3 = s3[:sv_check]
  107. vh3 = vh3[:sv_check, :]
  108. A3 = u3 @ np.diag(s3) @ vh3
  109. recon = u @ np.diag(s) @ vh
  110. assert_allclose(np.linalg.norm(A3 - recon), 0, atol=atol)
  111. @pytest.mark.parametrize('shifts', (None, -10, 0, 1, 10, 70))
  112. @pytest.mark.parametrize('dtype', _dtypes[:2])
  113. def test_shifts(shifts, dtype):
  114. rng = np.random.default_rng(0)
  115. n, k = 70, 10
  116. A = rng.random((n, n))
  117. if shifts is not None and ((shifts < 0) or (k > min(n-1-shifts, n))):
  118. with pytest.raises(ValueError):
  119. _svdp(A, k, shifts=shifts, kmax=5*k, irl_mode=True, rng=rng)
  120. else:
  121. _svdp(A, k, shifts=shifts, kmax=5*k, irl_mode=True, rng=rng)
  122. @pytest.mark.slow
  123. @pytest.mark.xfail()
  124. def test_shifts_accuracy():
  125. rng = np.random.default_rng(0)
  126. n, k = 70, 10
  127. A = rng.random((n, n)).astype(np.float64)
  128. u1, s1, vt1, _ = _svdp(A, k, shifts=None, which='SM', irl_mode=True, rng=rng)
  129. u2, s2, vt2, _ = _svdp(A, k, shifts=32, which='SM', irl_mode=True, rng=rng)
  130. # shifts <= 32 doesn't agree with shifts > 32
  131. # Does agree when which='LM' instead of 'SM'
  132. assert_allclose(s1, s2)
  133. @pytest.mark.parametrize('irl_mode', [False, True])
  134. @pytest.mark.parametrize('dtype', (np.float32, np.float64))
  135. def test_thin_hilbert(irl_mode, dtype):
  136. rng = np.random.default_rng(1757951587606893)
  137. m, n = 200, 4
  138. # Generate a Hilbert matrix of size m x n
  139. A = np.array([[1 / (i + j + 1) for j in range(n)] for i in range(m)], dtype=dtype)
  140. uu, ss, vv = np.linalg.svd(A, full_matrices=False)
  141. u, s, vt, _ = _svdp(A, k=4, which='LM', irl_mode=irl_mode, rng=rng)
  142. assert_allclose(s, ss, atol=TOLS[dtype])
  143. # Check orthogonality of singular vectors
  144. assert_allclose(np.eye(u.shape[1]), u.T @ u, atol=TOLS[dtype])
  145. assert_allclose(np.eye(vt.shape[0]), vt @ vt.T, atol=TOLS[dtype])
  146. # Check orthogonality against numpy svd results
  147. assert_allclose(np.abs(uu.T @ u), np.eye(n), atol=TOLS[dtype])
  148. assert_allclose(np.abs(vv @ vt.T), np.eye(n), atol=TOLS[dtype])
  149. @pytest.mark.parametrize('dtype', (np.float32, np.float64, np.complex64, np.complex128))
  150. def test_fat_random(dtype):
  151. rng = np.random.default_rng(1758046113948869)
  152. m, n = 3, 100
  153. A = rng.uniform(size=(m, n)).astype(dtype)
  154. if dtype in (np.complex64, np.complex128):
  155. A += dtype(1j) * rng.uniform(size=(m, n)).astype(dtype)
  156. uu, ss, vv = np.linalg.svd(A, full_matrices=False)
  157. u, s, vt, _ = _svdp(A, k=3, which='LM', irl_mode=True, rng=rng)
  158. assert_allclose(s, ss, atol=TOLS[dtype])
  159. # Check orthogonality of singular vectors
  160. assert_allclose(np.eye(u.shape[1]), u.conj().T @ u, atol=TOLS[dtype])
  161. assert_allclose(np.eye(vt.shape[0]), vt @ vt.conj().T, atol=TOLS[dtype])
  162. # Check orthogonality against numpy svd results
  163. assert_allclose(np.abs(uu.conj().T @ u), np.eye(m), atol=TOLS[dtype])
  164. assert_allclose(np.abs(vv @ vt.conj().T), np.eye(m), atol=TOLS[dtype])