test_solvers.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862
  1. import os
  2. import numpy as np
  3. from numpy.testing import assert_array_almost_equal, assert_allclose
  4. import pytest
  5. from pytest import raises as assert_raises
  6. from scipy.linalg import solve_sylvester
  7. from scipy.linalg import solve_continuous_lyapunov, solve_discrete_lyapunov
  8. from scipy.linalg import solve_continuous_are, solve_discrete_are
  9. from scipy.linalg import block_diag, solve, LinAlgError
  10. from scipy.sparse._sputils import matrix
  11. from scipy.conftest import skip_xp_invalid_arg
# Dtypes used to parametrize the size-0 tests in this module, following the
# precedent set in gh-20295.
dtypes = [int, float, np.float32, complex, np.complex64]
  14. def _load_data(name):
  15. """
  16. Load npz data file under data/
  17. Returns a copy of the data, rather than keeping the npz file open.
  18. """
  19. filename = os.path.join(os.path.abspath(os.path.dirname(__file__)),
  20. 'data', name)
  21. with np.load(filename) as f:
  22. return dict(f.items())
  23. class TestSolveLyapunov:
  24. cases = [
  25. # empty case
  26. (np.empty((0, 0)),
  27. np.empty((0, 0))),
  28. (np.array([[1, 2], [3, 4]]),
  29. np.array([[9, 10], [11, 12]])),
  30. # a, q all complex.
  31. (np.array([[1.0+1j, 2.0], [3.0-4.0j, 5.0]]),
  32. np.array([[2.0-2j, 2.0+2j], [-1.0-1j, 2.0]])),
  33. # a real; q complex.
  34. (np.array([[1.0, 2.0], [3.0, 5.0]]),
  35. np.array([[2.0-2j, 2.0+2j], [-1.0-1j, 2.0]])),
  36. # a complex; q real.
  37. (np.array([[1.0+1j, 2.0], [3.0-4.0j, 5.0]]),
  38. np.array([[2.0, 2.0], [-1.0, 2.0]])),
  39. # An example from Kitagawa, 1977
  40. (np.array([[3, 9, 5, 1, 4], [1, 2, 3, 8, 4], [4, 6, 6, 6, 3],
  41. [1, 5, 2, 0, 7], [5, 3, 3, 1, 5]]),
  42. np.array([[2, 4, 1, 0, 1], [4, 1, 0, 2, 0], [1, 0, 3, 0, 3],
  43. [0, 2, 0, 1, 0], [1, 0, 3, 0, 4]])),
  44. # Companion matrix example. a complex; q real; a.shape[0] = 11
  45. (np.array([[0.100+0.j, 0.091+0.j, 0.082+0.j, 0.073+0.j, 0.064+0.j,
  46. 0.055+0.j, 0.046+0.j, 0.037+0.j, 0.028+0.j, 0.019+0.j,
  47. 0.010+0.j],
  48. [1.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  49. 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  50. 0.000+0.j],
  51. [0.000+0.j, 1.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  52. 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  53. 0.000+0.j],
  54. [0.000+0.j, 0.000+0.j, 1.000+0.j, 0.000+0.j, 0.000+0.j,
  55. 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  56. 0.000+0.j],
  57. [0.000+0.j, 0.000+0.j, 0.000+0.j, 1.000+0.j, 0.000+0.j,
  58. 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  59. 0.000+0.j],
  60. [0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 1.000+0.j,
  61. 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  62. 0.000+0.j],
  63. [0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  64. 1.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  65. 0.000+0.j],
  66. [0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  67. 0.000+0.j, 1.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  68. 0.000+0.j],
  69. [0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  70. 0.000+0.j, 0.000+0.j, 1.000+0.j, 0.000+0.j, 0.000+0.j,
  71. 0.000+0.j],
  72. [0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  73. 0.000+0.j, 0.000+0.j, 0.000+0.j, 1.000+0.j, 0.000+0.j,
  74. 0.000+0.j],
  75. [0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  76. 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 1.000+0.j,
  77. 0.000+0.j]]),
  78. np.eye(11)),
  79. # https://github.com/scipy/scipy/issues/4176
  80. (matrix([[0, 1], [-1/2, -1]]),
  81. (matrix([0, 3]).T @ matrix([0, 3]).T.T)),
  82. # https://github.com/scipy/scipy/issues/4176
  83. (matrix([[0, 1], [-1/2, -1]]),
  84. (np.array(matrix([0, 3]).T @ matrix([0, 3]).T.T))),
  85. ]
  86. def test_continuous_squareness_and_shape(self):
  87. nsq = np.ones((3, 2))
  88. sq = np.eye(3)
  89. assert_raises(ValueError, solve_continuous_lyapunov, nsq, sq)
  90. assert_raises(ValueError, solve_continuous_lyapunov, sq, nsq)
  91. assert_raises(ValueError, solve_continuous_lyapunov, sq, np.eye(2))
  92. def check_continuous_case(self, a, q):
  93. x = solve_continuous_lyapunov(a, q)
  94. assert_array_almost_equal(
  95. np.dot(a, x) + np.dot(x, a.conj().transpose()), q)
  96. def check_discrete_case(self, a, q, method=None):
  97. x = solve_discrete_lyapunov(a, q, method=method)
  98. assert_array_almost_equal(
  99. np.dot(np.dot(a, x), a.conj().transpose()) - x, -1.0*q)
  100. @skip_xp_invalid_arg
  101. def test_cases(self):
  102. for case in self.cases:
  103. self.check_continuous_case(case[0], case[1])
  104. self.check_discrete_case(case[0], case[1])
  105. self.check_discrete_case(case[0], case[1], method='direct')
  106. self.check_discrete_case(case[0], case[1], method='bilinear')
  107. @pytest.mark.parametrize("dtype_a", dtypes)
  108. @pytest.mark.parametrize("dtype_q", dtypes)
  109. def test_size_0(self, dtype_a, dtype_q):
  110. rng = np.random.default_rng(234598235)
  111. a = np.zeros((0, 0), dtype=dtype_a)
  112. q = np.zeros((0, 0), dtype=dtype_q)
  113. res = solve_continuous_lyapunov(a, q)
  114. a = (rng.random((5, 5))*100).astype(dtype_a)
  115. q = (rng.random((5, 5))*100).astype(dtype_q)
  116. ref = solve_continuous_lyapunov(a, q)
  117. assert res.shape == (0, 0)
  118. assert res.dtype == ref.dtype
  119. class TestSolveContinuousAre:
  120. mat6 = _load_data('carex_6_data.npz')
  121. mat15 = _load_data('carex_15_data.npz')
  122. mat18 = _load_data('carex_18_data.npz')
  123. mat19 = _load_data('carex_19_data.npz')
  124. mat20 = _load_data('carex_20_data.npz')
  125. cases = [
  126. # Carex examples taken from (with default parameters):
  127. # [1] P.BENNER, A.J. LAUB, V. MEHRMANN: 'A Collection of Benchmark
  128. # Examples for the Numerical Solution of Algebraic Riccati
  129. # Equations II: Continuous-Time Case', Tech. Report SPC 95_23,
  130. # Fak. f. Mathematik, TU Chemnitz-Zwickau (Germany), 1995.
  131. #
  132. # The format of the data is (a, b, q, r, knownfailure), where
  133. # knownfailure is None if the test passes or a string
  134. # indicating the reason for failure.
  135. #
  136. # Test Case 0: carex #1
  137. (np.diag([1.], 1),
  138. np.array([[0], [1]]),
  139. block_diag(1., 2.),
  140. 1,
  141. None),
  142. # Test Case 1: carex #2
  143. (np.array([[4, 3], [-4.5, -3.5]]),
  144. np.array([[1], [-1]]),
  145. np.array([[9, 6], [6, 4.]]),
  146. 1,
  147. None),
  148. # Test Case 2: carex #3
  149. (np.array([[0, 1, 0, 0],
  150. [0, -1.89, 0.39, -5.53],
  151. [0, -0.034, -2.98, 2.43],
  152. [0.034, -0.0011, -0.99, -0.21]]),
  153. np.array([[0, 0], [0.36, -1.6], [-0.95, -0.032], [0.03, 0]]),
  154. np.array([[2.313, 2.727, 0.688, 0.023],
  155. [2.727, 4.271, 1.148, 0.323],
  156. [0.688, 1.148, 0.313, 0.102],
  157. [0.023, 0.323, 0.102, 0.083]]),
  158. np.eye(2),
  159. None),
  160. # Test Case 3: carex #4
  161. (np.array([[-0.991, 0.529, 0, 0, 0, 0, 0, 0],
  162. [0.522, -1.051, 0.596, 0, 0, 0, 0, 0],
  163. [0, 0.522, -1.118, 0.596, 0, 0, 0, 0],
  164. [0, 0, 0.522, -1.548, 0.718, 0, 0, 0],
  165. [0, 0, 0, 0.922, -1.64, 0.799, 0, 0],
  166. [0, 0, 0, 0, 0.922, -1.721, 0.901, 0],
  167. [0, 0, 0, 0, 0, 0.922, -1.823, 1.021],
  168. [0, 0, 0, 0, 0, 0, 0.922, -1.943]]),
  169. np.array([[3.84, 4.00, 37.60, 3.08, 2.36, 2.88, 3.08, 3.00],
  170. [-2.88, -3.04, -2.80, -2.32, -3.32, -3.82, -4.12, -3.96]]
  171. ).T * 0.001,
  172. np.array([[1.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.1],
  173. [0.0, 1.0, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0],
  174. [0.0, 0.0, 1.0, 0.0, 0.0, 0.5, 0.0, 0.0],
  175. [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
  176. [0.5, 0.1, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0],
  177. [0.0, 0.0, 0.5, 0.0, 0.0, 0.1, 0.0, 0.0],
  178. [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.0],
  179. [0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1]]),
  180. np.eye(2),
  181. None),
  182. # Test Case 4: carex #5
  183. (np.array(
  184. [[-4.019, 5.120, 0., 0., -2.082, 0., 0., 0., 0.870],
  185. [-0.346, 0.986, 0., 0., -2.340, 0., 0., 0., 0.970],
  186. [-7.909, 15.407, -4.069, 0., -6.450, 0., 0., 0., 2.680],
  187. [-21.816, 35.606, -0.339, -3.870, -17.800, 0., 0., 0., 7.390],
  188. [-60.196, 98.188, -7.907, 0.340, -53.008, 0., 0., 0., 20.400],
  189. [0, 0, 0, 0, 94.000, -147.200, 0., 53.200, 0.],
  190. [0, 0, 0, 0, 0, 94.000, -147.200, 0, 0],
  191. [0, 0, 0, 0, 0, 12.800, 0.000, -31.600, 0],
  192. [0, 0, 0, 0, 12.800, 0.000, 0.000, 18.800, -31.600]]),
  193. np.array([[0.010, -0.011, -0.151],
  194. [0.003, -0.021, 0.000],
  195. [0.009, -0.059, 0.000],
  196. [0.024, -0.162, 0.000],
  197. [0.068, -0.445, 0.000],
  198. [0.000, 0.000, 0.000],
  199. [0.000, 0.000, 0.000],
  200. [0.000, 0.000, 0.000],
  201. [0.000, 0.000, 0.000]]),
  202. np.eye(9),
  203. np.eye(3),
  204. None),
  205. # Test Case 5: carex #6
  206. (mat6['A'], mat6['B'], mat6['Q'], mat6['R'], None),
  207. # Test Case 6: carex #7
  208. (np.array([[1, 0], [0, -2.]]),
  209. np.array([[1e-6], [0]]),
  210. np.ones((2, 2)),
  211. 1.,
  212. 'Bad residual accuracy'),
  213. # Test Case 7: carex #8
  214. (block_diag(-0.1, -0.02),
  215. np.array([[0.100, 0.000], [0.001, 0.010]]),
  216. np.array([[100, 1000], [1000, 10000]]),
  217. np.ones((2, 2)) + block_diag(1e-6, 0),
  218. None),
  219. # Test Case 8: carex #9
  220. (np.array([[0, 1e6], [0, 0]]),
  221. np.array([[0], [1.]]),
  222. np.eye(2),
  223. 1.,
  224. None),
  225. # Test Case 9: carex #10
  226. (np.array([[1.0000001, 1], [1., 1.0000001]]),
  227. np.eye(2),
  228. np.eye(2),
  229. np.eye(2),
  230. None),
  231. # Test Case 10: carex #11
  232. (np.array([[3, 1.], [4, 2]]),
  233. np.array([[1], [1]]),
  234. np.array([[-11, -5], [-5, -2.]]),
  235. 1.,
  236. None),
  237. # Test Case 11: carex #12
  238. (np.array([[7000000., 2000000., -0.],
  239. [2000000., 6000000., -2000000.],
  240. [0., -2000000., 5000000.]]) / 3,
  241. np.eye(3),
  242. np.array([[1., -2., -2.], [-2., 1., -2.], [-2., -2., 1.]]).dot(
  243. np.diag([1e-6, 1, 1e6])).dot(
  244. np.array([[1., -2., -2.], [-2., 1., -2.], [-2., -2., 1.]])) / 9,
  245. np.eye(3) * 1e6,
  246. 'Bad Residual Accuracy'),
  247. # Test Case 12: carex #13
  248. (np.array([[0, 0.4, 0, 0],
  249. [0, 0, 0.345, 0],
  250. [0, -0.524e6, -0.465e6, 0.262e6],
  251. [0, 0, 0, -1e6]]),
  252. np.array([[0, 0, 0, 1e6]]).T,
  253. np.diag([1, 0, 1, 0]),
  254. 1.,
  255. None),
  256. # Test Case 13: carex #14
  257. (np.array([[-1e-6, 1, 0, 0],
  258. [-1, -1e-6, 0, 0],
  259. [0, 0, 1e-6, 1],
  260. [0, 0, -1, 1e-6]]),
  261. np.ones((4, 1)),
  262. np.ones((4, 4)),
  263. 1.,
  264. None),
  265. # Test Case 14: carex #15
  266. (mat15['A'], mat15['B'], mat15['Q'], mat15['R'], None),
  267. # Test Case 15: carex #16
  268. (np.eye(64, 64, k=-1) + np.eye(64, 64)*(-2.) + np.rot90(
  269. block_diag(1, np.zeros((62, 62)), 1)) + np.eye(64, 64, k=1),
  270. np.eye(64),
  271. np.eye(64),
  272. np.eye(64),
  273. None),
  274. # Test Case 16: carex #17
  275. (np.diag(np.ones((20, )), 1),
  276. np.flipud(np.eye(21, 1)),
  277. np.eye(21, 1) * np.eye(21, 1).T,
  278. 1,
  279. 'Bad Residual Accuracy'),
  280. # Test Case 17: carex #18
  281. (mat18['A'], mat18['B'], mat18['Q'], mat18['R'], None),
  282. # Test Case 18: carex #19
  283. (mat19['A'], mat19['B'], mat19['Q'], mat19['R'],
  284. 'Bad Residual Accuracy'),
  285. # Test Case 19: carex #20
  286. (mat20['A'], mat20['B'], mat20['Q'], mat20['R'],
  287. 'Bad Residual Accuracy')
  288. ]
  289. # Makes the minimum precision requirements customized to the test.
  290. # Here numbers represent the number of decimals that agrees with zero
  291. # matrix when the solution x is plugged in to the equation.
  292. #
  293. # res = array([[8e-3,1e-16],[1e-16,1e-20]]) --> min_decimal[k] = 2
  294. #
  295. # If the test is failing use "None" for that entry.
  296. #
  297. min_decimal = (14, 12, 13, 14, 11, 6, None, 5, 7, 14, 14,
  298. None, 9, 14, 13, 14, None, 12, None, None)
  299. @pytest.mark.parametrize("j, case", enumerate(cases))
  300. def test_solve_continuous_are(self, j, case):
  301. """Checks if 0 = XA + A'X - XB(R)^{-1} B'X + Q is true"""
  302. a, b, q, r, knownfailure = case
  303. if knownfailure:
  304. pytest.xfail(reason=knownfailure)
  305. dec = self.min_decimal[j]
  306. x = solve_continuous_are(a, b, q, r)
  307. res = x @ a + a.conj().T @ x + q
  308. out_fact = x @ b
  309. res -= out_fact @ solve(np.atleast_2d(r), out_fact.conj().T)
  310. assert_array_almost_equal(res, np.zeros_like(res), decimal=dec)
  311. class TestSolveDiscreteAre:
  312. cases = [
  313. # Darex examples taken from (with default parameters):
  314. # [1] P.BENNER, A.J. LAUB, V. MEHRMANN: 'A Collection of Benchmark
  315. # Examples for the Numerical Solution of Algebraic Riccati
  316. # Equations II: Discrete-Time Case', Tech. Report SPC 95_23,
  317. # Fak. f. Mathematik, TU Chemnitz-Zwickau (Germany), 1995.
  318. # [2] T. GUDMUNDSSON, C. KENNEY, A.J. LAUB: 'Scaling of the
  319. # Discrete-Time Algebraic Riccati Equation to Enhance Stability
  320. # of the Schur Solution Method', IEEE Trans.Aut.Cont., vol.37(4)
  321. #
  322. # The format of the data is (a, b, q, r, knownfailure), where
  323. # knownfailure is None if the test passes or a string
  324. # indicating the reason for failure.
  325. #
  326. # TEST CASE 0 : Complex a; real b, q, r
  327. (np.array([[2, 1-2j], [0, -3j]]),
  328. np.array([[0], [1]]),
  329. np.array([[1, 0], [0, 2]]),
  330. np.array([[1]]),
  331. None),
  332. # TEST CASE 1 :Real a, q, r; complex b
  333. (np.array([[2, 1], [0, -1]]),
  334. np.array([[-2j], [1j]]),
  335. np.array([[1, 0], [0, 2]]),
  336. np.array([[1]]),
  337. None),
  338. # TEST CASE 2 : Real a, b; complex q, r
  339. (np.array([[3, 1], [0, -1]]),
  340. np.array([[1, 2], [1, 3]]),
  341. np.array([[1, 1+1j], [1-1j, 2]]),
  342. np.array([[2, -2j], [2j, 3]]),
  343. None),
  344. # TEST CASE 3 : User-reported gh-2251 (Trac #1732)
  345. (np.array([[0.63399379, 0.54906824, 0.76253406],
  346. [0.5404729, 0.53745766, 0.08731853],
  347. [0.27524045, 0.84922129, 0.4681622]]),
  348. np.array([[0.96861695], [0.05532739], [0.78934047]]),
  349. np.eye(3),
  350. np.eye(1),
  351. None),
  352. # TEST CASE 4 : darex #1
  353. (np.array([[4, 3], [-4.5, -3.5]]),
  354. np.array([[1], [-1]]),
  355. np.array([[9, 6], [6, 4]]),
  356. np.array([[1]]),
  357. None),
  358. # TEST CASE 5 : darex #2
  359. (np.array([[0.9512, 0], [0, 0.9048]]),
  360. np.array([[4.877, 4.877], [-1.1895, 3.569]]),
  361. np.array([[0.005, 0], [0, 0.02]]),
  362. np.array([[1/3, 0], [0, 3]]),
  363. None),
  364. # TEST CASE 6 : darex #3
  365. (np.array([[2, -1], [1, 0]]),
  366. np.array([[1], [0]]),
  367. np.array([[0, 0], [0, 1]]),
  368. np.array([[0]]),
  369. None),
  370. # TEST CASE 7 : darex #4 (skipped the gen. Ric. term S)
  371. (np.array([[0, 1], [0, -1]]),
  372. np.array([[1, 0], [2, 1]]),
  373. np.array([[-4, -4], [-4, 7]]) * (1/11),
  374. np.array([[9, 3], [3, 1]]),
  375. None),
  376. # TEST CASE 8 : darex #5
  377. (np.array([[0, 1], [0, 0]]),
  378. np.array([[0], [1]]),
  379. np.array([[1, 2], [2, 4]]),
  380. np.array([[1]]),
  381. None),
  382. # TEST CASE 9 : darex #6
  383. (np.array([[0.998, 0.067, 0, 0],
  384. [-.067, 0.998, 0, 0],
  385. [0, 0, 0.998, 0.153],
  386. [0, 0, -.153, 0.998]]),
  387. np.array([[0.0033, 0.0200],
  388. [0.1000, -.0007],
  389. [0.0400, 0.0073],
  390. [-.0028, 0.1000]]),
  391. np.array([[1.87, 0, 0, -0.244],
  392. [0, 0.744, 0.205, 0],
  393. [0, 0.205, 0.589, 0],
  394. [-0.244, 0, 0, 1.048]]),
  395. np.eye(2),
  396. None),
  397. # TEST CASE 10 : darex #7
  398. (np.array([[0.984750, -.079903, 0.0009054, -.0010765],
  399. [0.041588, 0.998990, -.0358550, 0.0126840],
  400. [-.546620, 0.044916, -.3299100, 0.1931800],
  401. [2.662400, -.100450, -.9245500, -.2632500]]),
  402. np.array([[0.0037112, 0.0007361],
  403. [-.0870510, 9.3411e-6],
  404. [-1.198440, -4.1378e-4],
  405. [-3.192700, 9.2535e-4]]),
  406. np.eye(4)*1e-2,
  407. np.eye(2),
  408. None),
  409. # TEST CASE 11 : darex #8
  410. (np.array([[-0.6000000, -2.2000000, -3.6000000, -5.4000180],
  411. [1.0000000, 0.6000000, 0.8000000, 3.3999820],
  412. [0.0000000, 1.0000000, 1.8000000, 3.7999820],
  413. [0.0000000, 0.0000000, 0.0000000, -0.9999820]]),
  414. np.array([[1.0, -1.0, -1.0, -1.0],
  415. [0.0, 1.0, -1.0, -1.0],
  416. [0.0, 0.0, 1.0, -1.0],
  417. [0.0, 0.0, 0.0, 1.0]]),
  418. np.array([[2, 1, 3, 6],
  419. [1, 2, 2, 5],
  420. [3, 2, 6, 11],
  421. [6, 5, 11, 22]]),
  422. np.eye(4),
  423. None),
  424. # TEST CASE 12 : darex #9
  425. (np.array([[95.4070, 1.9643, 0.3597, 0.0673, 0.0190],
  426. [40.8490, 41.3170, 16.0840, 4.4679, 1.1971],
  427. [12.2170, 26.3260, 36.1490, 15.9300, 12.3830],
  428. [4.1118, 12.8580, 27.2090, 21.4420, 40.9760],
  429. [0.1305, 0.5808, 1.8750, 3.6162, 94.2800]]) * 0.01,
  430. np.array([[0.0434, -0.0122],
  431. [2.6606, -1.0453],
  432. [3.7530, -5.5100],
  433. [3.6076, -6.6000],
  434. [0.4617, -0.9148]]) * 0.01,
  435. np.eye(5),
  436. np.eye(2),
  437. None),
  438. # TEST CASE 13 : darex #10
  439. (np.kron(np.eye(2), np.diag([1, 1], k=1)),
  440. np.kron(np.eye(2), np.array([[0], [0], [1]])),
  441. np.array([[1, 1, 0, 0, 0, 0],
  442. [1, 1, 0, 0, 0, 0],
  443. [0, 0, 0, 0, 0, 0],
  444. [0, 0, 0, 1, -1, 0],
  445. [0, 0, 0, -1, 1, 0],
  446. [0, 0, 0, 0, 0, 0]]),
  447. np.array([[3, 0], [0, 1]]),
  448. None),
  449. # TEST CASE 14 : darex #11
  450. (0.001 * np.array(
  451. [[870.1, 135.0, 11.59, .5014, -37.22, .3484, 0, 4.242, 7.249],
  452. [76.55, 897.4, 12.72, 0.5504, -40.16, .3743, 0, 4.53, 7.499],
  453. [-127.2, 357.5, 817, 1.455, -102.8, .987, 0, 11.85, 18.72],
  454. [-363.5, 633.9, 74.91, 796.6, -273.5, 2.653, 0, 31.72, 48.82],
  455. [-960, 1645.9, -128.9, -5.597, 71.42, 7.108, 0, 84.52, 125.9],
  456. [-664.4, 112.96, -88.89, -3.854, 84.47, 13.6, 0, 144.3, 101.6],
  457. [-410.2, 693, -54.71, -2.371, 66.49, 12.49, .1063, 99.97, 69.67],
  458. [-179.9, 301.7, -23.93, -1.035, 60.59, 22.16, 0, 213.9, 35.54],
  459. [-345.1, 580.4, -45.96, -1.989, 105.6, 19.86, 0, 219.1, 215.2]]),
  460. np.array([[4.7600, -0.5701, -83.6800],
  461. [0.8790, -4.7730, -2.7300],
  462. [1.4820, -13.1200, 8.8760],
  463. [3.8920, -35.1300, 24.8000],
  464. [10.3400, -92.7500, 66.8000],
  465. [7.2030, -61.5900, 38.3400],
  466. [4.4540, -36.8300, 20.2900],
  467. [1.9710, -15.5400, 6.9370],
  468. [3.7730, -30.2800, 14.6900]]) * 0.001,
  469. np.diag([50, 0, 0, 0, 50, 0, 0, 0, 0]),
  470. np.eye(3),
  471. None),
  472. # TEST CASE 15 : darex #12 - numerically least accurate example
  473. (np.array([[0, 1e6], [0, 0]]),
  474. np.array([[0], [1]]),
  475. np.eye(2),
  476. np.array([[1]]),
  477. None),
  478. # TEST CASE 16 : darex #13
  479. (np.array([[16, 10, -2],
  480. [10, 13, -8],
  481. [-2, -8, 7]]) * (1/9),
  482. np.eye(3),
  483. 1e6 * np.eye(3),
  484. 1e6 * np.eye(3),
  485. None),
  486. # TEST CASE 17 : darex #14
  487. (np.array([[1 - 1/1e8, 0, 0, 0],
  488. [1, 0, 0, 0],
  489. [0, 1, 0, 0],
  490. [0, 0, 1, 0]]),
  491. np.array([[1e-08], [0], [0], [0]]),
  492. np.diag([0, 0, 0, 1]),
  493. np.array([[0.25]]),
  494. None),
  495. # TEST CASE 18 : darex #15
  496. (np.eye(100, k=1),
  497. np.flipud(np.eye(100, 1)),
  498. np.eye(100),
  499. np.array([[1]]),
  500. None)
  501. ]
  502. # Makes the minimum precision requirements customized to the test.
  503. # Here numbers represent the number of decimals that agrees with zero
  504. # matrix when the solution x is plugged in to the equation.
  505. #
  506. # res = array([[8e-3,1e-16],[1e-16,1e-20]]) --> min_decimal[k] = 2
  507. #
  508. # If the test is failing use "None" for that entry.
  509. #
  510. min_decimal = (12, 14, 13, 14, 13, 16, 18, 14, 14, 13,
  511. 14, 13, 13, 14, 12, 2, 4, 6, 10)
  512. max_tol = [1.5 * 10**-ind for ind in min_decimal]
  513. # relaxed tolerance in gh-18012 after bump to OpenBLAS
  514. max_tol[11] = 2.5e-13
  515. # relaxed tolerance in gh-20335 for linux-aarch64 build on Cirrus
  516. # with OpenBLAS from ubuntu jammy
  517. max_tol[15] = 2.0e-2
  518. # relaxed tolerance in gh-20335 for OpenBLAS 3.20 on ubuntu jammy
  519. # bump not needed for OpenBLAS 3.26
  520. max_tol[16] = 2.0e-4
  521. @pytest.mark.parametrize("j, case", enumerate(cases))
  522. def test_solve_discrete_are(self, j, case):
  523. """Checks if X = A'XA-(A'XB)(R+B'XB)^-1(B'XA)+Q) is true"""
  524. a, b, q, r, knownfailure = case
  525. if knownfailure:
  526. pytest.xfail(reason=knownfailure)
  527. atol = self.max_tol[j]
  528. x = solve_discrete_are(a, b, q, r)
  529. bH = b.conj().T
  530. xa, xb = x @ a, x @ b
  531. res = a.conj().T @ xa - x + q
  532. res -= a.conj().T @ xb @ (solve(r + bH @ xb, bH) @ xa)
  533. # changed from
  534. # assert_array_almost_equal(res, np.zeros_like(res), decimal=dec)
  535. # in gh-18012 as it's easier to relax a tolerance and allclose is
  536. # preferred
  537. assert_allclose(res, np.zeros_like(res), atol=atol)
  538. def test_infeasible(self):
  539. # An infeasible example taken from https://arxiv.org/abs/1505.04861v1
  540. A = np.triu(np.ones((3, 3)))
  541. A[0, 1] = -1
  542. B = np.array([[1, 1, 0], [0, 0, 1]]).T
  543. Q = np.full_like(A, -2) + np.diag([8, -1, -1.9])
  544. R = np.diag([-10, 0.1])
  545. assert_raises(LinAlgError, solve_continuous_are, A, B, Q, R)
  546. class TestSolveCommonAre:
  547. @pytest.mark.parametrize("solver", [solve_continuous_are, solve_discrete_are])
  548. def test_with_skipped_array_argument_gh23336(self, solver):
  549. # gh-23336 reported a failure when optional argument `e` was skipped
  550. A = np.array([[-0.9, 0.25], [0, -1.1]])
  551. B = np.array([[0.23], [0.45]])
  552. Q = np.eye(2)
  553. R = np.atleast_2d(0.45)
  554. E = np.eye(2)
  555. S = np.array([[0.1], [0.2]])
  556. res = solver(A, B, Q, R, s=S)
  557. ref = solver(A, B, Q, R, E, S)
  558. np.testing.assert_allclose(res, ref)
  559. def test_solve_generalized_continuous_are():
  560. cases = [
  561. # Two random examples differ by s term
  562. # in the absence of any literature for demanding examples.
  563. (np.array([[2.769230e-01, 8.234578e-01, 9.502220e-01],
  564. [4.617139e-02, 6.948286e-01, 3.444608e-02],
  565. [9.713178e-02, 3.170995e-01, 4.387444e-01]]),
  566. np.array([[3.815585e-01, 1.868726e-01],
  567. [7.655168e-01, 4.897644e-01],
  568. [7.951999e-01, 4.455862e-01]]),
  569. np.eye(3),
  570. np.eye(2),
  571. np.array([[6.463130e-01, 2.760251e-01, 1.626117e-01],
  572. [7.093648e-01, 6.797027e-01, 1.189977e-01],
  573. [7.546867e-01, 6.550980e-01, 4.983641e-01]]),
  574. np.zeros((3, 2)),
  575. None),
  576. (np.array([[2.769230e-01, 8.234578e-01, 9.502220e-01],
  577. [4.617139e-02, 6.948286e-01, 3.444608e-02],
  578. [9.713178e-02, 3.170995e-01, 4.387444e-01]]),
  579. np.array([[3.815585e-01, 1.868726e-01],
  580. [7.655168e-01, 4.897644e-01],
  581. [7.951999e-01, 4.455862e-01]]),
  582. np.eye(3),
  583. np.eye(2),
  584. np.array([[6.463130e-01, 2.760251e-01, 1.626117e-01],
  585. [7.093648e-01, 6.797027e-01, 1.189977e-01],
  586. [7.546867e-01, 6.550980e-01, 4.983641e-01]]),
  587. np.ones((3, 2)),
  588. None)
  589. ]
  590. min_decimal = (10, 10)
  591. def _test_factory(case, dec):
  592. """Checks if X = A'XA-(A'XB)(R+B'XB)^-1(B'XA)+Q) is true"""
  593. a, b, q, r, e, s, knownfailure = case
  594. if knownfailure:
  595. pytest.xfail(reason=knownfailure)
  596. x = solve_continuous_are(a, b, q, r, e, s)
  597. res = a.conj().T.dot(x.dot(e)) + e.conj().T.dot(x.dot(a)) + q
  598. out_fact = e.conj().T.dot(x).dot(b) + s
  599. res -= out_fact.dot(solve(np.atleast_2d(r), out_fact.conj().T))
  600. assert_array_almost_equal(res, np.zeros_like(res), decimal=dec)
  601. for ind, case in enumerate(cases):
  602. _test_factory(case, min_decimal[ind])
  603. def test_solve_generalized_discrete_are():
  604. mat20170120 = _load_data('gendare_20170120_data.npz')
  605. cases = [
  606. # Two random examples differ by s term
  607. # in the absence of any literature for demanding examples.
  608. (np.array([[2.769230e-01, 8.234578e-01, 9.502220e-01],
  609. [4.617139e-02, 6.948286e-01, 3.444608e-02],
  610. [9.713178e-02, 3.170995e-01, 4.387444e-01]]),
  611. np.array([[3.815585e-01, 1.868726e-01],
  612. [7.655168e-01, 4.897644e-01],
  613. [7.951999e-01, 4.455862e-01]]),
  614. np.eye(3),
  615. np.eye(2),
  616. np.array([[6.463130e-01, 2.760251e-01, 1.626117e-01],
  617. [7.093648e-01, 6.797027e-01, 1.189977e-01],
  618. [7.546867e-01, 6.550980e-01, 4.983641e-01]]),
  619. np.zeros((3, 2)),
  620. None),
  621. (np.array([[2.769230e-01, 8.234578e-01, 9.502220e-01],
  622. [4.617139e-02, 6.948286e-01, 3.444608e-02],
  623. [9.713178e-02, 3.170995e-01, 4.387444e-01]]),
  624. np.array([[3.815585e-01, 1.868726e-01],
  625. [7.655168e-01, 4.897644e-01],
  626. [7.951999e-01, 4.455862e-01]]),
  627. np.eye(3),
  628. np.eye(2),
  629. np.array([[6.463130e-01, 2.760251e-01, 1.626117e-01],
  630. [7.093648e-01, 6.797027e-01, 1.189977e-01],
  631. [7.546867e-01, 6.550980e-01, 4.983641e-01]]),
  632. np.ones((3, 2)),
  633. None),
  634. # user-reported (under PR-6616) 20-Jan-2017
  635. # tests against the case where E is None but S is provided
  636. (mat20170120['A'],
  637. mat20170120['B'],
  638. mat20170120['Q'],
  639. mat20170120['R'],
  640. None,
  641. mat20170120['S'],
  642. None),
  643. ]
  644. max_atol = (1.5e-11, 1.5e-11, 3.5e-16)
  645. def _test_factory(case, atol):
  646. """Checks if X = A'XA-(A'XB)(R+B'XB)^-1(B'XA)+Q) is true"""
  647. a, b, q, r, e, s, knownfailure = case
  648. if knownfailure:
  649. pytest.xfail(reason=knownfailure)
  650. x = solve_discrete_are(a, b, q, r, e, s)
  651. if e is None:
  652. e = np.eye(a.shape[0])
  653. if s is None:
  654. s = np.zeros_like(b)
  655. res = a.conj().T.dot(x.dot(a)) - e.conj().T.dot(x.dot(e)) + q
  656. res -= (a.conj().T.dot(x.dot(b)) + s).dot(
  657. solve(r+b.conj().T.dot(x.dot(b)),
  658. (b.conj().T.dot(x.dot(a)) + s.conj().T)
  659. )
  660. )
  661. # changed from:
  662. # assert_array_almost_equal(res, np.zeros_like(res), decimal=dec)
  663. # in gh-17950 because of a Linux 32 bit fail.
  664. assert_allclose(res, np.zeros_like(res), atol=atol)
  665. for ind, case in enumerate(cases):
  666. _test_factory(case, max_atol[ind])
  667. def test_are_validate_args():
  668. def test_square_shape():
  669. nsq = np.ones((3, 2))
  670. sq = np.eye(3)
  671. for x in (solve_continuous_are, solve_discrete_are):
  672. assert_raises(ValueError, x, nsq, 1, 1, 1)
  673. assert_raises(ValueError, x, sq, sq, nsq, 1)
  674. assert_raises(ValueError, x, sq, sq, sq, nsq)
  675. assert_raises(ValueError, x, sq, sq, sq, sq, nsq)
  676. def test_compatible_sizes():
  677. nsq = np.ones((3, 2))
  678. sq = np.eye(4)
  679. for x in (solve_continuous_are, solve_discrete_are):
  680. assert_raises(ValueError, x, sq, nsq, 1, 1)
  681. assert_raises(ValueError, x, sq, sq, sq, sq, sq, nsq)
  682. assert_raises(ValueError, x, sq, sq, np.eye(3), sq)
  683. assert_raises(ValueError, x, sq, sq, sq, np.eye(3))
  684. assert_raises(ValueError, x, sq, sq, sq, sq, np.eye(3))
  685. def test_symmetry():
  686. nsym = np.arange(9).reshape(3, 3)
  687. sym = np.eye(3)
  688. for x in (solve_continuous_are, solve_discrete_are):
  689. assert_raises(ValueError, x, sym, sym, nsym, sym)
  690. assert_raises(ValueError, x, sym, sym, sym, nsym)
  691. def test_singularity():
  692. sing = np.full((3, 3), 1e12)
  693. sing[2, 2] -= 1
  694. sq = np.eye(3)
  695. for x in (solve_continuous_are, solve_discrete_are):
  696. assert_raises(ValueError, x, sq, sq, sq, sq, sing)
  697. assert_raises(ValueError, solve_continuous_are, sq, sq, sq, sing)
  698. def test_finiteness():
  699. nm = np.full((2, 2), np.nan)
  700. sq = np.eye(2)
  701. for x in (solve_continuous_are, solve_discrete_are):
  702. assert_raises(ValueError, x, nm, sq, sq, sq)
  703. assert_raises(ValueError, x, sq, nm, sq, sq)
  704. assert_raises(ValueError, x, sq, sq, nm, sq)
  705. assert_raises(ValueError, x, sq, sq, sq, nm)
  706. assert_raises(ValueError, x, sq, sq, sq, sq, nm)
  707. assert_raises(ValueError, x, sq, sq, sq, sq, sq, nm)
  708. class TestSolveSylvester:
  709. cases = [
  710. # empty cases
  711. (np.empty((0, 0)),
  712. np.empty((0, 0)),
  713. np.empty((0, 0))),
  714. (np.empty((0, 0)),
  715. np.empty((2, 2)),
  716. np.empty((0, 2))),
  717. (np.empty((2, 2)),
  718. np.empty((0, 0)),
  719. np.empty((2, 0))),
  720. # a, b, c all real.
  721. (np.array([[1, 2], [0, 4]]),
  722. np.array([[5, 6], [0, 8]]),
  723. np.array([[9, 10], [11, 12]])),
  724. # a, b, c all real, 4x4. a and b have non-trivial 2x2 blocks in their
  725. # quasi-triangular form.
  726. (np.array([[1.0, 0, 0, 0],
  727. [0, 1.0, 2.0, 0.0],
  728. [0, 0, 3.0, -4],
  729. [0, 0, 2, 5]]),
  730. np.array([[2.0, 0, 0, 1.0],
  731. [0, 1.0, 0.0, 0.0],
  732. [0, 0, 1.0, -1],
  733. [0, 0, 1, 1]]),
  734. np.array([[1.0, 0, 0, 0],
  735. [0, 1.0, 0, 0],
  736. [0, 0, 1.0, 0],
  737. [0, 0, 0, 1.0]])),
  738. # a, b, c all complex.
  739. (np.array([[1.0+1j, 2.0], [3.0-4.0j, 5.0]]),
  740. np.array([[-1.0, 2j], [3.0, 4.0]]),
  741. np.array([[2.0-2j, 2.0+2j], [-1.0-1j, 2.0]])),
  742. # a and b real; c complex.
  743. (np.array([[1.0, 2.0], [3.0, 5.0]]),
  744. np.array([[-1.0, 0], [3.0, 4.0]]),
  745. np.array([[2.0-2j, 2.0+2j], [-1.0-1j, 2.0]])),
  746. # a and c complex; b real.
  747. (np.array([[1.0+1j, 2.0], [3.0-4.0j, 5.0]]),
  748. np.array([[-1.0, 0], [3.0, 4.0]]),
  749. np.array([[2.0-2j, 2.0+2j], [-1.0-1j, 2.0]])),
  750. # a complex; b and c real.
  751. (np.array([[1.0+1j, 2.0], [3.0-4.0j, 5.0]]),
  752. np.array([[-1.0, 0], [3.0, 4.0]]),
  753. np.array([[2.0, 2.0], [-1.0, 2.0]])),
  754. # not square matrices, real
  755. (np.array([[8, 1, 6], [3, 5, 7], [4, 9, 2]]),
  756. np.array([[2, 3], [4, 5]]),
  757. np.array([[1, 2], [3, 4], [5, 6]])),
  758. # not square matrices, complex
  759. (np.array([[8, 1j, 6+2j], [3, 5, 7], [4, 9, 2]]),
  760. np.array([[2, 3], [4, 5-1j]]),
  761. np.array([[1, 2j], [3, 4j], [5j, 6+7j]])),
  762. ]
  763. def check_case(self, a, b, c):
  764. x = solve_sylvester(a, b, c)
  765. assert_array_almost_equal(np.dot(a, x) + np.dot(x, b), c)
  766. def test_cases(self):
  767. for case in self.cases:
  768. self.check_case(case[0], case[1], case[2])
  769. def test_trivial(self):
  770. a = np.array([[1.0, 0.0], [0.0, 1.0]])
  771. b = np.array([[1.0]])
  772. c = np.array([2.0, 2.0]).reshape(-1, 1)
  773. x = solve_sylvester(a, b, c)
  774. assert_array_almost_equal(x, np.array([1.0, 1.0]).reshape(-1, 1))
  775. # Feel free to adjust this to test fewer dtypes or random selections rather than
  776. # the Cartesian product. It doesn't take very long to test all combinations,
  777. # though, so we'll start there and trim it down as we see fit.
  778. @pytest.mark.parametrize("dtype_a", dtypes)
  779. @pytest.mark.parametrize("dtype_b", dtypes)
  780. @pytest.mark.parametrize("dtype_q", dtypes)
  781. @pytest.mark.parametrize("m", [0, 3])
  782. @pytest.mark.parametrize("n", [0, 3])
  783. def test_size_0(self, m, n, dtype_a, dtype_b, dtype_q):
  784. if m == n != 0:
  785. pytest.skip('m = n != 0 is not a case that needs to be tested here.')
  786. rng = np.random.default_rng(598435298262546)
  787. a = np.zeros((m, m), dtype=dtype_a)
  788. b = np.zeros((n, n), dtype=dtype_b)
  789. q = np.zeros((m, n), dtype=dtype_q)
  790. res = solve_sylvester(a, b, q)
  791. a = (rng.random((5, 5))*100).astype(dtype_a)
  792. b = (rng.random((6, 6))*100).astype(dtype_b)
  793. q = (rng.random((5, 6))*100).astype(dtype_q)
  794. ref = solve_sylvester(a, b, q)
  795. assert res.shape == (m, n)
  796. assert res.dtype == ref.dtype