_procrustes.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. """
  2. Solve the orthogonal Procrustes problem.
  3. """
  4. from scipy._lib._util import _apply_over_batch
  5. from ._decomp_svd import svd
  6. from scipy._lib._array_api import array_namespace, xp_capabilities, _asarray, is_numpy
# Public API of this module: only the Procrustes solver is exported.
__all__ = ['orthogonal_procrustes']
  8. @xp_capabilities(
  9. jax_jit=False,
  10. skip_backends=[("dask.array", "full_matrices=True is not supported by dask")],
  11. )
  12. @_apply_over_batch(('A', 2), ('B', 2))
  13. def orthogonal_procrustes(A, B, check_finite=True):
  14. """
  15. Compute the matrix solution of the orthogonal (or unitary) Procrustes problem.
  16. Given matrices `A` and `B` of the same shape, find an orthogonal (or unitary in
  17. the case of complex input) matrix `R` that most closely maps `A` to `B` using the
  18. algorithm given in [1]_.
  19. Parameters
  20. ----------
  21. A : (M, N) array_like
  22. Matrix to be mapped.
  23. B : (M, N) array_like
  24. Target matrix.
  25. check_finite : bool, optional
  26. Whether to check that the input matrices contain only finite numbers.
  27. Disabling may give a performance gain, but may result in problems
  28. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  29. Returns
  30. -------
  31. R : (N, N) ndarray
  32. The matrix solution of the orthogonal Procrustes problem.
  33. Minimizes the Frobenius norm of ``(A @ R) - B``, subject to
  34. ``R.conj().T @ R = I``.
  35. scale : float
  36. Sum of the singular values of ``A.conj().T @ B``.
  37. Raises
  38. ------
  39. ValueError
  40. If the input array shapes don't match or if check_finite is True and
  41. the arrays contain Inf or NaN.
  42. Notes
  43. -----
  44. Note that unlike higher level Procrustes analyses of spatial data, this
  45. function only uses orthogonal transformations like rotations and
  46. reflections, and it does not use scaling or translation.
  47. References
  48. ----------
  49. .. [1] Peter H. Schonemann, "A generalized solution of the orthogonal
  50. Procrustes problem", Psychometrica -- Vol. 31, No. 1, March, 1966.
  51. :doi:`10.1007/BF02289451`
  52. Examples
  53. --------
  54. >>> import numpy as np
  55. >>> from scipy.linalg import orthogonal_procrustes
  56. >>> A = np.array([[ 2, 0, 1], [-2, 0, 0]])
  57. Flip the order of columns and check for the anti-diagonal mapping
  58. >>> R, sca = orthogonal_procrustes(A, np.fliplr(A))
  59. >>> R
  60. array([[-5.34384992e-17, 0.00000000e+00, 1.00000000e+00],
  61. [ 0.00000000e+00, 1.00000000e+00, 0.00000000e+00],
  62. [ 1.00000000e+00, 0.00000000e+00, -7.85941422e-17]])
  63. >>> sca
  64. 9.0
  65. As an example of the unitary Procrustes problem, generate a
  66. random complex matrix ``A``, a random unitary matrix ``Q``,
  67. and their product ``B``.
  68. >>> shape = (4, 4)
  69. >>> rng = np.random.default_rng(589234981235)
  70. >>> A = rng.random(shape) + rng.random(shape)*1j
  71. >>> Q = rng.random(shape) + rng.random(shape)*1j
  72. >>> Q, _ = np.linalg.qr(Q)
  73. >>> B = A @ Q
  74. `orthogonal_procrustes` recovers the unitary matrix ``Q``
  75. from ``A`` and ``B``.
  76. >>> R, _ = orthogonal_procrustes(A, B)
  77. >>> np.allclose(R, Q)
  78. True
  79. """
  80. xp = array_namespace(A, B)
  81. A = _asarray(A, xp=xp, check_finite=check_finite, subok=True)
  82. B = _asarray(B, xp=xp, check_finite=check_finite, subok=True)
  83. if A.ndim != 2:
  84. raise ValueError(f'expected ndim to be 2, but observed {A.ndim}')
  85. if A.shape != B.shape:
  86. raise ValueError(f'the shapes of A and B differ ({A.shape} vs {B.shape})')
  87. # Be clever with transposes, with the intention to save memory.
  88. # The conjugate has no effect for real inputs, but gives the correct solution
  89. # for complex inputs.
  90. if is_numpy(xp):
  91. u, w, vt = svd((B.T @ xp.conj(A)).T)
  92. else:
  93. u, w, vt = xp.linalg.svd((B.T @ xp.conj(A)).T)
  94. R = u @ vt
  95. scale = xp.sum(w)
  96. return R, scale