test_procrustes.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. from itertools import product, permutations
  2. import numpy as np
  3. import pytest
  4. from numpy.testing import assert_allclose
  5. from pytest import raises as assert_raises
  6. from scipy.linalg import orthogonal_procrustes
  7. from scipy.sparse._sputils import matrix
  8. from scipy._lib._array_api import make_xp_test_case, xp_assert_close
  9. from scipy.conftest import skip_xp_invalid_arg
  10. def _centered(A, xp):
  11. mu = xp.mean(A, axis=0)
  12. return A - mu, mu
  13. @make_xp_test_case(orthogonal_procrustes)
  14. class TestOrthogonalProcrustes:
  15. def test_orthogonal_procrustes_ndim_too_small(self, xp):
  16. rng = np.random.RandomState(1234)
  17. A = xp.asarray(rng.randn(3))
  18. B = xp.asarray(rng.randn(3))
  19. assert_raises(ValueError, orthogonal_procrustes, A, B)
  20. def test_orthogonal_procrustes_shape_mismatch(self, xp):
  21. rng = np.random.RandomState(1234)
  22. shapes = ((3, 3), (3, 4), (4, 3), (4, 4))
  23. for a, b in permutations(shapes, 2):
  24. A = xp.asarray(rng.randn(*a))
  25. B = xp.asarray(rng.randn(*b))
  26. assert_raises(ValueError, orthogonal_procrustes, A, B)
  27. def test_orthogonal_procrustes_checkfinite_exception(self, xp):
  28. rng = np.random.RandomState(1234)
  29. m, n = 2, 3
  30. A_good = rng.randn(m, n)
  31. B_good = rng.randn(m, n)
  32. for bad_value in np.inf, -np.inf, np.nan:
  33. A_bad = A_good.copy()
  34. A_bad[1, 2] = bad_value
  35. B_bad = B_good.copy()
  36. B_bad[1, 2] = bad_value
  37. for A, B in ((A_good, B_bad), (A_bad, B_good), (A_bad, B_bad)):
  38. assert_raises(ValueError, orthogonal_procrustes, xp.asarray(A),
  39. xp.asarray(B))
  40. def test_orthogonal_procrustes_scale_invariance(self, xp):
  41. rng = np.random.RandomState(1234)
  42. m, n = 4, 3
  43. for i in range(3):
  44. A_orig = xp.asarray(rng.randn(m, n))
  45. B_orig = xp.asarray(rng.randn(m, n))
  46. R_orig, s = orthogonal_procrustes(A_orig, B_orig)
  47. for A_scale in np.square(rng.randn(3)):
  48. for B_scale in np.square(rng.randn(3)):
  49. R, s = orthogonal_procrustes(A_orig * xp.asarray(A_scale),
  50. B_orig * xp.asarray(B_scale))
  51. xp_assert_close(R, R_orig)
  52. @skip_xp_invalid_arg()
  53. def test_orthogonal_procrustes_array_conversion(self):
  54. rng = np.random.RandomState(1234)
  55. for m, n in ((6, 4), (4, 4), (4, 6)):
  56. A_arr = rng.randn(m, n)
  57. B_arr = rng.randn(m, n)
  58. As = (A_arr, A_arr.tolist(), matrix(A_arr))
  59. Bs = (B_arr, B_arr.tolist(), matrix(B_arr))
  60. R_arr, s = orthogonal_procrustes(A_arr, B_arr)
  61. AR_arr = A_arr.dot(R_arr)
  62. for A, B in product(As, Bs):
  63. R, s = orthogonal_procrustes(A, B)
  64. AR = A_arr.dot(R)
  65. assert_allclose(AR, AR_arr)
  66. def test_orthogonal_procrustes(self, xp):
  67. rng = np.random.RandomState(1234)
  68. for m, n in ((6, 4), (4, 4), (4, 6)):
  69. # Sample a random target matrix.
  70. B = xp.asarray(rng.randn(m, n))
  71. # Sample a random orthogonal matrix
  72. # by computing eigh of a sampled symmetric matrix.
  73. X = xp.asarray(rng.randn(n, n))
  74. w, V = xp.linalg.eigh(X.T + X)
  75. xp_assert_close(xp.linalg.inv(V), V.T)
  76. # Compute a matrix with a known orthogonal transformation that gives B.
  77. A = B @ V.T
  78. # Check that an orthogonal transformation from A to B can be recovered.
  79. R, s = orthogonal_procrustes(A, B)
  80. xp_assert_close(xp.linalg.inv(R), R.T)
  81. xp_assert_close(A @ R, B)
  82. # Create a perturbed input matrix.
  83. A_perturbed = A + 1e-2 * xp.asarray(rng.randn(m, n))
  84. # Check that the orthogonal procrustes function can find an orthogonal
  85. # transformation that is better than the orthogonal transformation
  86. # computed from the original input matrix.
  87. R_prime, s = orthogonal_procrustes(A_perturbed, B)
  88. xp_assert_close(xp.linalg.inv(R_prime), R_prime.T)
  89. # Compute the naive and optimal transformations of the perturbed input.
  90. naive_approx = A_perturbed @ R
  91. optim_approx = A_perturbed @ R_prime
  92. # Compute the Frobenius norm errors of the matrix approximations.
  93. naive_approx_error = xp.linalg.matrix_norm(naive_approx - B, ord='fro')
  94. optim_approx_error = xp.linalg.matrix_norm(optim_approx - B, ord='fro')
  95. # Check that the orthogonal Procrustes approximation is better.
  96. assert xp.all(optim_approx_error < naive_approx_error)
  97. def test_orthogonal_procrustes_exact_example(self, xp):
  98. # Check a small application.
  99. # It uses translation, scaling, reflection, and rotation.
  100. #
  101. # |
  102. # a b |
  103. # |
  104. # d c | w
  105. # |
  106. # --------+--- x ----- z ---
  107. # |
  108. # | y
  109. # |
  110. #
  111. A_orig = xp.asarray([[-3, 3], [-2, 3], [-2, 2], [-3, 2]], dtype=xp.float64)
  112. B_orig = xp.asarray([[3, 2], [1, 0], [3, -2], [5, 0]], dtype=xp.float64)
  113. A, A_mu = _centered(A_orig, xp)
  114. B, B_mu = _centered(B_orig, xp)
  115. R, s = orthogonal_procrustes(A, B)
  116. scale = s / xp.linalg.matrix_norm(A)**2
  117. B_approx = scale * A @ R + B_mu
  118. xp_assert_close(B_approx, B_orig, atol=1e-8)
  119. def test_orthogonal_procrustes_stretched_example(self, xp):
  120. # Try again with a target with a stretched y axis.
  121. A_orig = xp.asarray([[-3, 3], [-2, 3], [-2, 2], [-3, 2]], dtype=xp.float64)
  122. B_orig = xp.asarray([[3, 40], [1, 0], [3, -40], [5, 0]], dtype=xp.float64)
  123. A, A_mu = _centered(A_orig, xp)
  124. B, B_mu = _centered(B_orig, xp)
  125. R, s = orthogonal_procrustes(A, B)
  126. scale = s / xp.linalg.matrix_norm(A)**2
  127. B_approx = scale * A @ R + B_mu
  128. expected = xp.asarray([[3, 21], [-18, 0], [3, -21], [24, 0]], dtype=xp.float64)
  129. xp_assert_close(B_approx, expected, atol=1e-8)
  130. # Check disparity symmetry.
  131. expected_disparity = xp.asarray(0.4501246882793018, dtype=xp.float64)[()]
  132. AB_disparity = (xp.linalg.matrix_norm(B_approx - B_orig)
  133. / xp.linalg.matrix_norm(B))**2
  134. xp_assert_close(AB_disparity, expected_disparity)
  135. R, s = orthogonal_procrustes(B, A)
  136. scale = s / xp.linalg.matrix_norm(B)**2
  137. A_approx = scale * B @ R + A_mu
  138. BA_disparity = (xp.linalg.matrix_norm(A_approx - A_orig)
  139. / xp.linalg.matrix_norm(A))**2
  140. xp_assert_close(BA_disparity, expected_disparity)
  141. def test_orthogonal_procrustes_skbio_example(self, xp):
  142. # This transformation is also exact.
  143. # It uses translation, scaling, and reflection.
  144. #
  145. # |
  146. # | a
  147. # | b
  148. # | c d
  149. # --+---------
  150. # |
  151. # | w
  152. # |
  153. # | x
  154. # |
  155. # | z y
  156. # |
  157. #
  158. A_orig = xp.asarray([[4, -2], [4, -4], [4, -6], [2, -6]], dtype=xp.float64)
  159. B_orig = xp.asarray([[1, 3], [1, 2], [1, 1], [2, 1]], dtype=xp.float64)
  160. B_standardized = xp.asarray([[-0.13363062, 0.6681531],
  161. [-0.13363062, 0.13363062],
  162. [-0.13363062, -0.40089186],
  163. [0.40089186, -0.40089186]], dtype=xp.float64)
  164. A, A_mu = _centered(A_orig, xp)
  165. B, B_mu = _centered(B_orig, xp)
  166. R, s = orthogonal_procrustes(A, B)
  167. scale = s / xp.linalg.matrix_norm(A)**2
  168. B_approx = scale * A @ R + B_mu
  169. xp_assert_close(B_approx, B_orig)
  170. xp_assert_close(B / xp.linalg.matrix_norm(B), B_standardized)
  171. def test_empty(self, xp):
  172. a = xp.empty((0, 0))
  173. r, s = orthogonal_procrustes(a, a)
  174. xp_assert_close(r, xp.empty((0, 0)))
  175. a = xp.empty((0, 3))
  176. r, s = orthogonal_procrustes(a, a)
  177. xp_assert_close(r, xp.eye(3))
  178. @pytest.mark.parametrize('shape', [(4, 5), (5, 5), (5, 4)])
  179. def test_unitary(self, shape, xp):
  180. # gh-12071 added support for unitary matrices; check that it
  181. # works as intended.
  182. m, n = shape
  183. rng = np.random.default_rng(589234981235)
  184. A = xp.asarray(rng.random(shape) + rng.random(shape) * 1j)
  185. Q = xp.asarray(rng.random((n, n)) + rng.random((n, n)) * 1j)
  186. Q, _ = xp.linalg.qr(Q)
  187. B = A @ Q
  188. R, scale = orthogonal_procrustes(A, B)
  189. xp_assert_close(R @ xp.conj(R).T, xp.eye(n, dtype=xp.complex128), atol=1e-14)
  190. xp_assert_close(A @ Q, B)
  191. if shape != (4, 5): # solution is unique
  192. xp_assert_close(R, Q)
  193. _, s, _ = xp.linalg.svd(xp.conj(A).T @ B)
  194. xp_assert_close(scale, xp.sum(s))