test_decomp_polar.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import pytest
  2. import numpy as np
  3. from numpy.linalg import norm
  4. from numpy.testing import (assert_, assert_allclose, assert_equal)
  5. from scipy.linalg import polar, eigh
  6. diag2 = np.array([[2, 0], [0, 3]])
  7. a13 = np.array([[1, 2, 2]])
  8. precomputed_cases = [
  9. [[[0]], 'right', [[1]], [[0]]],
  10. [[[0]], 'left', [[1]], [[0]]],
  11. [[[9]], 'right', [[1]], [[9]]],
  12. [[[9]], 'left', [[1]], [[9]]],
  13. [diag2, 'right', np.eye(2), diag2],
  14. [diag2, 'left', np.eye(2), diag2],
  15. [a13, 'right', a13/norm(a13[0]), a13.T.dot(a13)/norm(a13[0])],
  16. ]
  17. verify_cases = [
  18. [[1, 2], [3, 4]],
  19. [[1, 2, 3]],
  20. [[1], [2], [3]],
  21. [[1, 2, 3], [3, 4, 0]],
  22. [[1, 2], [3, 4], [5, 5]],
  23. [[1, 2], [3, 4+5j]],
  24. [[1, 2, 3j]],
  25. [[1], [2], [3j]],
  26. [[1, 2, 3+2j], [3, 4-1j, -4j]],
  27. [[1, 2], [3-2j, 4+0.5j], [5, 5]],
  28. [[10000, 10, 1], [-1, 2, 3j], [0, 1, 2]],
  29. np.empty((0, 0)),
  30. np.empty((0, 2)),
  31. np.empty((2, 0)),
  32. ]
  33. def check_precomputed_polar(a, side, expected_u, expected_p):
  34. # Compare the result of the polar decomposition to a
  35. # precomputed result.
  36. u, p = polar(a, side=side)
  37. assert_allclose(u, expected_u, atol=1e-15)
  38. assert_allclose(p, expected_p, atol=1e-15)
  39. def verify_polar(a):
  40. # Compute the polar decomposition, and then verify that
  41. # the result has all the expected properties.
  42. product_atol = np.sqrt(np.finfo(float).eps)
  43. aa = np.asarray(a)
  44. m, n = aa.shape
  45. u, p = polar(a, side='right')
  46. assert_equal(u.shape, (m, n))
  47. assert_equal(p.shape, (n, n))
  48. # a = up
  49. assert_allclose(u.dot(p), a, atol=product_atol)
  50. if m >= n:
  51. assert_allclose(u.conj().T.dot(u), np.eye(n), atol=1e-15)
  52. else:
  53. assert_allclose(u.dot(u.conj().T), np.eye(m), atol=1e-15)
  54. # p is Hermitian positive semidefinite.
  55. assert_allclose(p.conj().T, p)
  56. evals = eigh(p, eigvals_only=True)
  57. nonzero_evals = evals[abs(evals) > 1e-14]
  58. assert_((nonzero_evals >= 0).all())
  59. u, p = polar(a, side='left')
  60. assert_equal(u.shape, (m, n))
  61. assert_equal(p.shape, (m, m))
  62. # a = pu
  63. assert_allclose(p.dot(u), a, atol=product_atol)
  64. if m >= n:
  65. assert_allclose(u.conj().T.dot(u), np.eye(n), atol=1e-15)
  66. else:
  67. assert_allclose(u.dot(u.conj().T), np.eye(m), atol=1e-15)
  68. # p is Hermitian positive semidefinite.
  69. assert_allclose(p.conj().T, p)
  70. evals = eigh(p, eigvals_only=True)
  71. nonzero_evals = evals[abs(evals) > 1e-14]
  72. assert_((nonzero_evals >= 0).all())
  73. def test_precomputed_cases():
  74. for a, side, expected_u, expected_p in precomputed_cases:
  75. check_precomputed_polar(a, side, expected_u, expected_p)
  76. def test_verify_cases():
  77. for a in verify_cases:
  78. verify_polar(a)
  79. @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  80. @pytest.mark.parametrize('shape', [(0, 0), (0, 2), (2, 0)])
  81. @pytest.mark.parametrize('side', ['left', 'right'])
  82. def test_empty(dt, shape, side):
  83. a = np.empty(shape, dtype=dt)
  84. m, n = shape
  85. p_shape = (m, m) if side == 'left' else (n, n)
  86. u, p = polar(a, side=side)
  87. u_n, p_n = polar(np.eye(5, dtype=dt))
  88. assert_equal(u.dtype, u_n.dtype)
  89. assert_equal(p.dtype, p_n.dtype)
  90. assert u.shape == shape
  91. assert p.shape == p_shape
  92. assert np.all(p == 0)