test_matfuncs.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121
  1. #
  2. # Created by: Pearu Peterson, March 2002
  3. #
  4. """ Test functions for linalg.matfuncs module
  5. """
  6. import functools
  7. import pytest
  8. import warnings
  9. import numpy as np
  10. from numpy import array, identity, sqrt
  11. from numpy.testing import (assert_array_almost_equal, assert_allclose, assert_,
  12. assert_array_less, assert_array_equal)
  13. import scipy.linalg
  14. from scipy.linalg import (funm, signm, logm, sqrtm, fractional_matrix_power,
  15. expm, expm_frechet, expm_cond, norm, khatri_rao,
  16. cosm, sinm, tanm, coshm, sinhm, tanhm)
  17. from scipy.linalg import _matfuncs_inv_ssq
  18. from scipy.linalg._matfuncs import pick_pade_structure
  19. from scipy.linalg._matfuncs_inv_ssq import LogmExactlySingularWarning
  20. import scipy.linalg._expm_frechet
  21. from scipy.linalg import LinAlgWarning
  22. from scipy.optimize import minimize
  23. def _get_al_mohy_higham_2012_experiment_1():
  24. """
  25. Return the test matrix from Experiment (1) of [1]_.
  26. References
  27. ----------
  28. .. [1] Awad H. Al-Mohy and Nicholas J. Higham (2012)
  29. "Improved Inverse Scaling and Squaring Algorithms
  30. for the Matrix Logarithm."
  31. SIAM Journal on Scientific Computing, 34 (4). C152-C169.
  32. ISSN 1095-7197
  33. """
  34. A = np.array([
  35. [3.2346e-1, 3e4, 3e4, 3e4],
  36. [0, 3.0089e-1, 3e4, 3e4],
  37. [0, 0, 3.2210e-1, 3e4],
  38. [0, 0, 0, 3.0744e-1]], dtype=float)
  39. return A
  40. class TestSignM:
  41. def test_nils(self):
  42. a = array([[29.2, -24.2, 69.5, 49.8, 7.],
  43. [-9.2, 5.2, -18., -16.8, -2.],
  44. [-10., 6., -20., -18., -2.],
  45. [-9.6, 9.6, -25.5, -15.4, -2.],
  46. [9.8, -4.8, 18., 18.2, 2.]])
  47. cr = array([[11.94933333,-2.24533333,15.31733333,21.65333333,-2.24533333],
  48. [-3.84266667,0.49866667,-4.59066667,-7.18666667,0.49866667],
  49. [-4.08,0.56,-4.92,-7.6,0.56],
  50. [-4.03466667,1.04266667,-5.59866667,-7.02666667,1.04266667],
  51. [4.15733333,-0.50133333,4.90933333,7.81333333,-0.50133333]])
  52. r = signm(a)
  53. assert_array_almost_equal(r,cr)
  54. def test_defective1(self):
  55. a = array([[0.0,1,0,0],[1,0,1,0],[0,0,0,1],[0,0,1,0]])
  56. signm(a)
  57. #XXX: what would be the correct result?
  58. def test_defective2(self):
  59. a = array((
  60. [29.2,-24.2,69.5,49.8,7.0],
  61. [-9.2,5.2,-18.0,-16.8,-2.0],
  62. [-10.0,6.0,-20.0,-18.0,-2.0],
  63. [-9.6,9.6,-25.5,-15.4,-2.0],
  64. [9.8,-4.8,18.0,18.2,2.0]))
  65. signm(a)
  66. #XXX: what would be the correct result?
  67. def test_defective3(self):
  68. a = array([[-2., 25., 0., 0., 0., 0., 0.],
  69. [0., -3., 10., 3., 3., 3., 0.],
  70. [0., 0., 2., 15., 3., 3., 0.],
  71. [0., 0., 0., 0., 15., 3., 0.],
  72. [0., 0., 0., 0., 3., 10., 0.],
  73. [0., 0., 0., 0., 0., -2., 25.],
  74. [0., 0., 0., 0., 0., 0., -3.]])
  75. signm(a)
  76. #XXX: what would be the correct result?
  77. class TestLogM:
  78. @pytest.mark.filterwarnings("ignore:.*inaccurate.*:RuntimeWarning")
  79. def test_nils(self):
  80. a = array([[-2., 25., 0., 0., 0., 0., 0.],
  81. [0., -3., 10., 3., 3., 3., 0.],
  82. [0., 0., 2., 15., 3., 3., 0.],
  83. [0., 0., 0., 0., 15., 3., 0.],
  84. [0., 0., 0., 0., 3., 10., 0.],
  85. [0., 0., 0., 0., 0., -2., 25.],
  86. [0., 0., 0., 0., 0., 0., -3.]])
  87. m = (identity(7)*3.1+0j)-a
  88. logm(m)
  89. #XXX: what would be the correct result?
  90. @pytest.mark.filterwarnings("ignore:.*inaccurate.*:RuntimeWarning")
  91. def test_al_mohy_higham_2012_experiment_1_logm(self):
  92. # The logm completes the round trip successfully.
  93. # Note that the expm leg of the round trip is badly conditioned.
  94. A = _get_al_mohy_higham_2012_experiment_1()
  95. A_logm = logm(A)
  96. A_round_trip = expm(A_logm)
  97. assert_allclose(A_round_trip, A, rtol=5e-5, atol=1e-14)
  98. def test_al_mohy_higham_2012_experiment_1_funm_log(self):
  99. # The raw funm with np.log does not complete the round trip.
  100. # Note that the expm leg of the round trip is badly conditioned.
  101. A = _get_al_mohy_higham_2012_experiment_1()
  102. A_funm_log = funm(A, np.log)
  103. A_round_trip = expm(A_funm_log)
  104. assert_(not np.allclose(A_round_trip, A, rtol=1e-5, atol=1e-14))
  105. def test_round_trip_random_float(self):
  106. rng = np.random.default_rng(1738098768840254)
  107. for n in range(1, 6):
  108. M_unscaled = rng.uniform(size=(n, n))
  109. for scale in np.logspace(-4, 4, 9):
  110. M = M_unscaled * scale
  111. # Eigenvalues are related to the branch cut.
  112. W = np.linalg.eigvals(M)
  113. err_msg = f'M:{M} eivals:{W}'
  114. # Check sqrtm round trip because it is used within logm.
  115. M_sqrtm = sqrtm(M)
  116. M_sqrtm_round_trip = M_sqrtm @ M_sqrtm
  117. assert_allclose(M_sqrtm_round_trip, M)
  118. # Check logm round trip.
  119. with warnings.catch_warnings():
  120. warnings.simplefilter("ignore", RuntimeWarning)
  121. M_logm = logm(M)
  122. M_logm_round_trip = expm(M_logm)
  123. assert_allclose(M_logm_round_trip, M, err_msg=err_msg)
  124. def test_round_trip_random_complex(self):
  125. rng = np.random.default_rng(1738098768840254)
  126. for n in range(1, 6):
  127. M_unscaled = (rng.standard_normal((n, n)) +
  128. 1j*rng.standard_normal((n, n)))
  129. for scale in np.logspace(-4, 4, 9):
  130. M = M_unscaled * scale
  131. M_logm = logm(M)
  132. M_round_trip = expm(M_logm)
  133. assert_allclose(M_round_trip, M)
  134. def test_logm_type_preservation_and_conversion(self):
  135. # The logm matrix function should preserve the type of a matrix
  136. # whose eigenvalues are positive with zero imaginary part.
  137. # Test this preservation for variously structured matrices.
  138. complex_dtype_chars = ('F', 'D', 'G')
  139. for matrix_as_list in (
  140. [[1, 0], [0, 1]],
  141. [[1, 0], [1, 1]],
  142. [[2, 1], [1, 1]],
  143. [[2, 3], [1, 2]]):
  144. # check that the spectrum has the expected properties
  145. W = scipy.linalg.eigvals(matrix_as_list)
  146. assert_(not any(w.imag or w.real < 0 for w in W))
  147. # check float type preservation
  148. A = np.array(matrix_as_list, dtype=float)
  149. A_logm = logm(A)
  150. assert_(A_logm.dtype.char not in complex_dtype_chars)
  151. # check complex type preservation
  152. A = np.array(matrix_as_list, dtype=complex)
  153. A_logm = logm(A)
  154. assert_(A_logm.dtype.char in complex_dtype_chars)
  155. # check float->complex type conversion for the matrix negation
  156. A = -np.array(matrix_as_list, dtype=float)
  157. A_logm = logm(A)
  158. assert_(A_logm.dtype.char in complex_dtype_chars)
  159. def test_complex_spectrum_real_logm(self):
  160. # This matrix has complex eigenvalues and real logm.
  161. # Its output dtype depends on its input dtype.
  162. M = [[1, 1, 2], [2, 1, 1], [1, 2, 1]]
  163. for dt in float, complex:
  164. X = np.array(M, dtype=dt)
  165. w = scipy.linalg.eigvals(X)
  166. assert_(1e-2 < np.absolute(w.imag).sum())
  167. Y = logm(X)
  168. assert_(np.issubdtype(Y.dtype, np.inexact))
  169. assert_allclose(expm(Y), X)
  170. def test_real_mixed_sign_spectrum(self):
  171. # These matrices have real eigenvalues with mixed signs.
  172. # The output logm dtype is complex, regardless of input dtype.
  173. for M in (
  174. [[1, 0], [0, -1]],
  175. [[0, 1], [1, 0]]):
  176. for dt in float, complex:
  177. A = np.array(M, dtype=dt)
  178. A_logm, info = logm(A)
  179. assert_(np.issubdtype(A_logm.dtype, np.complexfloating))
  180. def test_exactly_singular(self):
  181. A = np.array([[0, 0], [1j, 1j]])
  182. B = np.asarray([[1, 1], [0, 0]])
  183. for M in A, A.T, B, B.T:
  184. with pytest.warns(_matfuncs_inv_ssq.LogmExactlySingularWarning):
  185. L = logm(M)
  186. E = expm(L)
  187. assert_allclose(E, M, atol=1e-14)
  188. def test_nearly_singular(self):
  189. M = np.array([[1e-100]])
  190. with pytest.warns(_matfuncs_inv_ssq.LogmNearlySingularWarning):
  191. L = logm(M)
  192. E = expm(L)
  193. assert_allclose(E, M, atol=1e-14)
  194. def test_opposite_sign_complex_eigenvalues(self):
  195. # See gh-6113
  196. E = [[0, 1], [-1, 0]]
  197. L = [[0, np.pi*0.5], [-np.pi*0.5, 0]]
  198. assert_allclose(expm(L), E, atol=1e-14)
  199. assert_allclose(logm(E), L, atol=1e-14)
  200. E = [[1j, 4], [0, -1j]]
  201. L = [[1j*np.pi*0.5, 2*np.pi], [0, -1j*np.pi*0.5]]
  202. assert_allclose(expm(L), E, atol=1e-14)
  203. assert_allclose(logm(E), L, atol=1e-14)
  204. E = [[1j, 0], [0, -1j]]
  205. L = [[1j*np.pi*0.5, 0], [0, -1j*np.pi*0.5]]
  206. assert_allclose(expm(L), E, atol=1e-14)
  207. assert_allclose(logm(E), L, atol=1e-14)
  208. def test_readonly(self):
  209. n = 5
  210. a = np.ones((n, n)) + np.identity(n)
  211. a.flags.writeable = False
  212. logm(a)
  213. @pytest.mark.xfail(reason="ValueError: attempt to get argmax of an empty sequence")
  214. @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  215. def test_empty(self, dt):
  216. a = np.empty((0, 0), dtype=dt)
  217. log_a = logm(a)
  218. a0 = np.eye(2, dtype=dt)
  219. log_a0 = logm(a0)
  220. assert log_a.shape == (0, 0)
  221. assert log_a.dtype == log_a0.dtype
  222. @pytest.mark.parametrize('dtype', [int, float, np.float32, complex, np.complex64])
  223. def test_no_ZeroDivisionError(self, dtype):
  224. # gh-17136 reported inconsistent behavior in `logm` depending on input dtype:
  225. # sometimes it raised an error, and sometimes it printed a warning message.
  226. # check that this is resolved and that the warning is emitted properly.
  227. with (pytest.warns(RuntimeWarning, match="logm result may be inaccurate"),
  228. pytest.warns(LogmExactlySingularWarning)):
  229. logm(np.zeros((2, 2), dtype=dtype))
  230. class TestSqrtM:
  231. def test_round_trip_random_float(self):
  232. rng = np.random.default_rng(1738151906092735)
  233. for n in range(1, 6):
  234. M_unscaled = rng.standard_normal((n, n))
  235. for scale in np.logspace(-4, 4, 9):
  236. M = M_unscaled * scale
  237. M_sqrtm = sqrtm(M)
  238. M_sqrtm_round_trip = M_sqrtm.dot(M_sqrtm)
  239. assert_allclose(M_sqrtm_round_trip, M)
  240. def test_round_trip_random_complex(self):
  241. rng = np.random.default_rng(1738151906092736)
  242. for n in range(1, 6):
  243. M_unscaled = (rng.standard_normal((n, n)) +
  244. 1j * rng.standard_normal((n, n)))
  245. for scale in np.logspace(-4, 4, 9):
  246. M = M_unscaled * scale
  247. M_sqrtm = sqrtm(M)
  248. M_sqrtm_round_trip = M_sqrtm.dot(M_sqrtm)
  249. assert_allclose(M_sqrtm_round_trip, M)
  250. def test_bad(self):
  251. # See https://web.archive.org/web/20051220232650/http://www.maths.man.ac.uk/~nareports/narep336.ps.gz
  252. e = 2**-5
  253. se = sqrt(e)
  254. a = array([[1.0,0,0,1],
  255. [0,e,0,0],
  256. [0,0,e,0],
  257. [0,0,0,1]])
  258. sa = array([[1,0,0,0.5],
  259. [0,se,0,0],
  260. [0,0,se,0],
  261. [0,0,0,1]])
  262. assert_array_almost_equal(sa @ sa, a)
  263. # Check default sqrtm.
  264. esa = sqrtm(a)
  265. assert_array_almost_equal(esa @ esa, a)
  266. def test_sqrtm_type_preservation_and_conversion(self):
  267. # The sqrtm matrix function should preserve the type of a matrix
  268. # whose eigenvalues are nonnegative with zero imaginary part.
  269. # Test this preservation for variously structured matrices.
  270. complex_dtype_chars = ('F', 'D', 'G')
  271. for matrix_as_list in (
  272. [[1, 0], [0, 1]],
  273. [[1, 0], [1, 1]],
  274. [[2, 1], [1, 1]],
  275. [[2, 3], [1, 2]],
  276. [[1, 1], [1, 1]]):
  277. # check that the spectrum has the expected properties
  278. W = scipy.linalg.eigvals(matrix_as_list)
  279. assert_(not any(w.imag or w.real < 0 for w in W))
  280. # Last test matrix is singular so suppress the warning
  281. with warnings.catch_warnings():
  282. warnings.simplefilter("ignore", LinAlgWarning)
  283. # check float type preservation
  284. A = np.array(matrix_as_list, dtype=float)
  285. A_sqrtm = sqrtm(A)
  286. assert_(A_sqrtm.dtype.char not in complex_dtype_chars)
  287. # check complex type preservation
  288. A = np.array(matrix_as_list, dtype=complex)
  289. A_sqrtm = sqrtm(A)
  290. assert_(A_sqrtm.dtype.char in complex_dtype_chars)
  291. # check float->complex type conversion for the matrix negation
  292. A = -np.array(matrix_as_list, dtype=float)
  293. A_sqrtm = sqrtm(A)
  294. assert_(A_sqrtm.dtype.char in complex_dtype_chars)
  295. def test_sqrtm_type_conversion_mixed_sign_or_complex_spectrum(self):
  296. complex_dtype_chars = ('F', 'D', 'G')
  297. for matrix_as_list in (
  298. [[1, 0], [0, -1]],
  299. [[0, 1], [1, 0]],
  300. [[0, 1, 0], [0, 0, -1], [1, 0, 0]]):
  301. # check that the spectrum has the expected properties
  302. W = scipy.linalg.eigvals(matrix_as_list)
  303. assert_(any(w.imag or w.real < 0 for w in W))
  304. # check complex->complex
  305. A = np.array(matrix_as_list, dtype=complex)
  306. A_sqrtm = sqrtm(A)
  307. assert_(A_sqrtm.dtype.char in complex_dtype_chars)
  308. # check float->complex
  309. A = np.array(matrix_as_list, dtype=float)
  310. A_sqrtm = sqrtm(A)
  311. assert_(A_sqrtm.dtype.char in complex_dtype_chars)
  312. def test_al_mohy_higham_2012_experiment_1(self):
  313. # Matrix square root of a tricky upper triangular matrix.
  314. A = _get_al_mohy_higham_2012_experiment_1()
  315. A_sqrtm = sqrtm(A)
  316. A_round_trip = A_sqrtm @ A_sqrtm
  317. assert_allclose(A_round_trip, A, rtol=1e-5)
  318. assert_allclose(np.tril(A_round_trip), np.tril(A))
  319. def test_strict_upper_triangular(self):
  320. # This matrix has no square root but upper triangular hence upper
  321. # triangle will be filled with junk values.
  322. for dt in int, float:
  323. A = np.array([
  324. [0, 3, 0, 0],
  325. [0, 0, 3, 0],
  326. [0, 0, 0, 3],
  327. [0, 0, 0, 0]], dtype=dt)
  328. with warnings.catch_warnings():
  329. warnings.simplefilter("ignore", LinAlgWarning)
  330. A_sqrtm = sqrtm(A)
  331. assert_allclose(np.tril(A_sqrtm), np.zeros((4, 4)))
  332. assert np.isnan(A_sqrtm).any()
  333. assert np.isinf(A_sqrtm).any()
  334. # Future edit: This squareroot is not possible to find algorithmically
  335. # with the current methods. Now sqrtm docstring has another example of
  336. # such matrix whose squareroot is not a polynomial in it. Hence no need
  337. # to test it here.
  338. """
  339. def test_weird_matrix(self):
  340. # The square root of matrix B exists.
  341. for dt in int, float:
  342. A = np.array([
  343. [0, 0, 1],
  344. [0, 0, 0],
  345. [0, 1, 0]], dtype=dt)
  346. B = np.array([
  347. [0, 1, 0],
  348. [0, 0, 0],
  349. [0, 0, 0]], dtype=dt)
  350. assert_array_equal(B, A @ A)
  351. # But scipy sqrtm is not clever enough to find it.
  352. B_sqrtm, info = sqrtm(B, disp=False)
  353. assert_(np.isnan(B_sqrtm).all())
  354. """
  355. def test_opposite_sign_complex_eigenvalues(self):
  356. M = [[2j, 4], [0, -2j]]
  357. R = [[1+1j, 2], [0, 1-1j]]
  358. assert_allclose(np.dot(R, R), M, atol=1e-14)
  359. assert_allclose(sqrtm(M), R, atol=1e-14)
  360. def test_gh4866(self):
  361. M = np.array([[1, 0, 0, 1],
  362. [0, 0, 0, 0],
  363. [0, 0, 0, 0],
  364. [1, 0, 0, 1]])
  365. R = np.array([[sqrt(0.5), 0, 0, sqrt(0.5)],
  366. [0, 0, 0, 0],
  367. [0, 0, 0, 0],
  368. [sqrt(0.5), 0, 0, sqrt(0.5)]])
  369. assert_allclose(R @ R, M, atol=1e-14)
  370. with warnings.catch_warnings():
  371. warnings.simplefilter("ignore", LinAlgWarning)
  372. assert_allclose(sqrtm(M), R, atol=1e-14)
  373. def test_gh5336(self):
  374. M = np.diag([2, 1, 0])
  375. R = np.diag([sqrt(2), 1, 0])
  376. assert_allclose(R @ R, M, atol=1e-14)
  377. with warnings.catch_warnings():
  378. warnings.filterwarnings("ignore", category=LinAlgWarning)
  379. assert_allclose(sqrtm(M), R, atol=1e-14)
  380. def test_gh7839(self):
  381. M = np.zeros((2, 2))
  382. R = np.zeros((2, 2))
  383. # Catch and silence LinAlgWarning
  384. with warnings.catch_warnings():
  385. warnings.filterwarnings("ignore", category=LinAlgWarning)
  386. assert_allclose(sqrtm(M), R, atol=1e-14)
  387. def test_gh17918(self):
  388. M = np.empty((19, 19))
  389. M.fill(0.94)
  390. np.fill_diagonal(M, 1)
  391. assert np.isrealobj(sqrtm(M))
  392. def test_gh23278(self):
  393. M = np.array([[1., 0., 0.], [0, 1, -1j], [0, 1j, 2]])
  394. sq = sqrtm(M)
  395. assert_allclose(sq @ sq, M, atol=1e-14)
  396. sq = sqrtm(M.astype(np.complex64))
  397. assert_allclose(sq @ sq, M, atol=1e-6)
  398. def test_data_size_preservation_uint_in_float_out(self):
  399. M = np.eye(10, dtype=np.uint8)
  400. assert sqrtm(M).dtype == np.float64
  401. M = np.eye(10, dtype=np.uint16)
  402. assert sqrtm(M).dtype == np.float64
  403. M = np.eye(10, dtype=np.uint32)
  404. assert sqrtm(M).dtype == np.float64
  405. M = np.eye(10, dtype=np.uint64)
  406. assert sqrtm(M).dtype == np.float64
  407. def test_data_size_preservation_int_in_float_out(self):
  408. M = np.eye(10, dtype=np.int8)
  409. assert sqrtm(M).dtype == np.float64
  410. M = np.eye(10, dtype=np.int16)
  411. assert sqrtm(M).dtype == np.float64
  412. M = np.eye(10, dtype=np.int32)
  413. assert sqrtm(M).dtype == np.float64
  414. M = np.eye(10, dtype=np.int64)
  415. assert sqrtm(M).dtype == np.float64
  416. def test_data_size_preservation_int_in_comp_out(self):
  417. M = np.array([[2, 4], [0, -2]], dtype=np.int8)
  418. assert sqrtm(M).dtype == np.complex128
  419. M = np.array([[2, 4], [0, -2]], dtype=np.int16)
  420. assert sqrtm(M).dtype == np.complex128
  421. M = np.array([[2, 4], [0, -2]], dtype=np.int32)
  422. assert sqrtm(M).dtype == np.complex128
  423. M = np.array([[2, 4], [0, -2]], dtype=np.int64)
  424. assert sqrtm(M).dtype == np.complex128
  425. def test_data_size_preservation_float_in_float_out(self):
  426. M = np.eye(10, dtype=np.float16)
  427. assert sqrtm(M).dtype == np.float32
  428. M = np.eye(10, dtype=np.float32)
  429. assert sqrtm(M).dtype == np.float32
  430. M = np.eye(10, dtype=np.float64)
  431. assert sqrtm(M).dtype == np.float64
  432. if hasattr(np, 'float128'):
  433. M = np.eye(10, dtype=np.float128)
  434. assert sqrtm(M).dtype == np.float64
  435. def test_data_size_preservation_float_in_comp_out(self):
  436. M = np.array([[2, 4], [0, -2]], dtype=np.float16)
  437. assert sqrtm(M).dtype == np.complex64
  438. M = np.array([[2, 4], [0, -2]], dtype=np.float32)
  439. assert sqrtm(M).dtype == np.complex64
  440. M = np.array([[2, 4], [0, -2]], dtype=np.float64)
  441. assert sqrtm(M).dtype == np.complex128
  442. if hasattr(np, 'float128') and hasattr(np, 'complex256'):
  443. M = np.array([[2, 4], [0, -2]], dtype=np.float128)
  444. assert sqrtm(M).dtype == np.complex128
  445. def test_data_size_preservation_comp_in_comp_out(self):
  446. M = np.array([[2j, 4], [0, -2j]], dtype=np.complex64)
  447. assert sqrtm(M).dtype == np.complex64
  448. M = np.array([[2j, 4], [0, -2j]], dtype=np.complex128)
  449. assert sqrtm(M).dtype == np.complex128
  450. if hasattr(np, 'complex256'):
  451. M = np.array([[2j, 4], [0, -2j]], dtype=np.complex256)
  452. assert sqrtm(M).dtype == np.complex128
  453. @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  454. def test_empty(self, dt):
  455. a = np.empty((0, 0), dtype=dt)
  456. s = sqrtm(a)
  457. a0 = np.eye(2, dtype=dt)
  458. s0 = sqrtm(a0)
  459. assert s.shape == (0, 0)
  460. assert s.dtype == s0.dtype
  461. def test_cf_noncontig_nd_inputs(self):
  462. # Check that non-contiguous arrays are handled correctly.
  463. # Generate an L, U pair for invertible random matrix.
  464. rng = np.random.default_rng(1738151906092737)
  465. n = 13
  466. A = rng.uniform(size=(3, 2*n, 2*n))
  467. L, U = np.tril(A, k=-1) + np.eye(2*n), np.triu(A)
  468. A = L @ U
  469. # Create strided views of 3D array.
  470. A_noncontig_c = A[:, ::2, ::2]
  471. A_noncontig_f = np.asfortranarray(A)[:, 1::2, 1::2]
  472. assert_allclose(sqrtm(A[:, ::2, ::2]), sqrtm(A_noncontig_c))
  473. assert_allclose(sqrtm(A[:, 1::2, 1::2]), sqrtm(A_noncontig_f))
  474. def test_empty_sizes(self):
  475. A = np.empty(shape=[4, 0, 5, 5], dtype=float)
  476. assert_array_equal(sqrtm(A), A)
  477. def test_negative_strides(self):
  478. rng = np.random.default_rng(1738151906092738)
  479. A = rng.uniform(size=(3, 5, 5))
  480. A_negneg_orig = A[:, ::-1, ::-1]
  481. A_negneg_copy = A[:, ::-1, ::-1].copy()
  482. assert_allclose(sqrtm(A_negneg_orig), sqrtm(A_negneg_copy))
  483. A_posneg_orig = A[:, :, ::-1]
  484. A_posneg_copy = A[:, :, ::-1].copy()
  485. assert_allclose(sqrtm(A_posneg_orig), sqrtm(A_posneg_copy))
  486. A_negpos_orig = A[:, ::-1, :]
  487. A_negpos_copy = A[:, ::-1, :].copy()
  488. assert_allclose(sqrtm(A_negpos_orig), sqrtm(A_negpos_copy))
  489. class TestFractionalMatrixPower:
  490. def test_round_trip_random_complex(self):
  491. rng = np.random.default_rng(1234)
  492. for p in range(1, 5):
  493. for n in range(1, 5):
  494. M_unscaled = (rng.standard_normal((n, n)) +
  495. 1j * rng.standard_normal((n, n)))
  496. for scale in np.logspace(-4, 4, 9):
  497. M = M_unscaled * scale
  498. M_root = fractional_matrix_power(M, 1/p)
  499. M_round_trip = np.linalg.matrix_power(M_root, p)
  500. assert_allclose(M_round_trip, M)
  501. def test_round_trip_random_float(self):
  502. # This test is more annoying because it can hit the branch cut;
  503. # this happens when the matrix has an eigenvalue
  504. # with no imaginary component and with a real negative component,
  505. # and it means that the principal branch does not exist.
  506. rng = np.random.default_rng(1234)
  507. for p in range(1, 5):
  508. for n in range(1, 5):
  509. M_unscaled = rng.standard_normal((n, n))
  510. for scale in np.logspace(-4, 4, 9):
  511. M = M_unscaled * scale
  512. M_root = fractional_matrix_power(M, 1/p)
  513. M_round_trip = np.linalg.matrix_power(M_root, p)
  514. assert_allclose(M_round_trip, M)
  515. def test_larger_abs_fractional_matrix_powers(self):
  516. rng = np.random.default_rng(1234)
  517. for n in (2, 3, 5):
  518. for i in range(10):
  519. M = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
  520. M_one_fifth = fractional_matrix_power(M, 0.2)
  521. # Test the round trip.
  522. M_round_trip = np.linalg.matrix_power(M_one_fifth, 5)
  523. assert_allclose(M, M_round_trip)
  524. # Test a large abs fractional power.
  525. X = fractional_matrix_power(M, -5.4)
  526. Y = np.linalg.matrix_power(M_one_fifth, -27)
  527. assert_allclose(X, Y)
  528. # Test another large abs fractional power.
  529. X = fractional_matrix_power(M, 3.8)
  530. Y = np.linalg.matrix_power(M_one_fifth, 19)
  531. assert_allclose(X, Y)
  532. def test_random_matrices_and_powers(self):
  533. # Each independent iteration of this fuzz test picks random parameters.
  534. # It tries to hit some edge cases.
  535. rng = np.random.default_rng(1726500458620605)
  536. nsamples = 20
  537. for i in range(nsamples):
  538. # Sample a matrix size and a random real power.
  539. n = rng.integers(1, 5)
  540. p = rng.random()
  541. # Sample a random real or complex matrix.
  542. matrix_scale = np.exp(rng.integers(-4, 5))
  543. A = rng.random(size=[n, n])
  544. if [True, False][rng.choice(2)]:
  545. A = A + 1j * rng.random(size=[n, n])
  546. A = A * matrix_scale
  547. # Check a couple of analytically equivalent ways
  548. # to compute the fractional matrix power.
  549. # These can be compared because they both use the principal branch.
  550. A_power = fractional_matrix_power(A, p)
  551. A_logm = logm(A)
  552. A_power_expm_logm = expm(A_logm * p)
  553. assert_allclose(A_power, A_power_expm_logm)
  554. def test_al_mohy_higham_2012_experiment_1(self):
  555. # Fractional powers of a tricky upper triangular matrix.
  556. A = _get_al_mohy_higham_2012_experiment_1()
  557. # Test remainder matrix power.
  558. A_funm_sqrt = funm(A, np.sqrt)
  559. A_sqrtm = sqrtm(A)
  560. A_rem_power = _matfuncs_inv_ssq._remainder_matrix_power(A, 0.5)
  561. A_power = fractional_matrix_power(A, 0.5)
  562. assert_allclose(A_rem_power, A_power, rtol=1e-11)
  563. assert_allclose(A_sqrtm, A_power)
  564. assert_allclose(A_sqrtm, A_funm_sqrt)
  565. # Test more fractional powers.
  566. for p in (1/2, 5/3):
  567. A_power = fractional_matrix_power(A, p)
  568. A_round_trip = fractional_matrix_power(A_power, 1/p)
  569. assert_allclose(A_round_trip, A, rtol=1e-2)
  570. assert_allclose(np.tril(A_round_trip, 1), np.tril(A, 1))
  571. def test_briggs_helper_function(self):
  572. rng = np.random.default_rng(1234)
  573. for a in rng.standard_normal(10) + 1j * rng.standard_normal(10):
  574. for k in range(5):
  575. x_observed = _matfuncs_inv_ssq._briggs_helper_function(a, k)
  576. x_expected = a ** np.exp2(-k) - 1
  577. assert_allclose(x_observed, x_expected)
  578. def test_type_preservation_and_conversion(self):
  579. # The fractional_matrix_power matrix function should preserve
  580. # the type of a matrix whose eigenvalues
  581. # are positive with zero imaginary part.
  582. # Test this preservation for variously structured matrices.
  583. complex_dtype_chars = ('F', 'D', 'G')
  584. for matrix_as_list in (
  585. [[1, 0], [0, 1]],
  586. [[1, 0], [1, 1]],
  587. [[2, 1], [1, 1]],
  588. [[2, 3], [1, 2]]):
  589. # check that the spectrum has the expected properties
  590. W = scipy.linalg.eigvals(matrix_as_list)
  591. assert_(not any(w.imag or w.real < 0 for w in W))
  592. # Check various positive and negative powers
  593. # with absolute values bigger and smaller than 1.
  594. for p in (-2.4, -0.9, 0.2, 3.3):
  595. # check float type preservation
  596. A = np.array(matrix_as_list, dtype=float)
  597. A_power = fractional_matrix_power(A, p)
  598. assert_(A_power.dtype.char not in complex_dtype_chars)
  599. # check complex type preservation
  600. A = np.array(matrix_as_list, dtype=complex)
  601. A_power = fractional_matrix_power(A, p)
  602. assert_(A_power.dtype.char in complex_dtype_chars)
  603. # check float->complex for the matrix negation
  604. A = -np.array(matrix_as_list, dtype=float)
  605. A_power = fractional_matrix_power(A, p)
  606. assert_(A_power.dtype.char in complex_dtype_chars)
  607. def test_type_conversion_mixed_sign_or_complex_spectrum(self):
  608. complex_dtype_chars = ('F', 'D', 'G')
  609. for matrix_as_list in (
  610. [[1, 0], [0, -1]],
  611. [[0, 1], [1, 0]],
  612. [[0, 1, 0], [0, 0, 1], [1, 0, 0]]):
  613. # check that the spectrum has the expected properties
  614. W = scipy.linalg.eigvals(matrix_as_list)
  615. assert_(any(w.imag or w.real < 0 for w in W))
  616. # Check various positive and negative powers
  617. # with absolute values bigger and smaller than 1.
  618. for p in (-2.4, -0.9, 0.2, 3.3):
  619. # check complex->complex
  620. A = np.array(matrix_as_list, dtype=complex)
  621. A_power = fractional_matrix_power(A, p)
  622. assert_(A_power.dtype.char in complex_dtype_chars)
  623. # check float->complex
  624. A = np.array(matrix_as_list, dtype=float)
  625. A_power = fractional_matrix_power(A, p)
  626. assert_(A_power.dtype.char in complex_dtype_chars)
  627. @pytest.mark.xfail(reason='Too unstable across LAPACKs.')
  628. def test_singular(self):
  629. # Negative fractional powers do not work with singular matrices.
  630. for matrix_as_list in (
  631. [[0, 0], [0, 0]],
  632. [[1, 1], [1, 1]],
  633. [[1, 2], [3, 6]],
  634. [[0, 0, 0], [0, 1, 1], [0, -1, 1]]):
  635. # Check fractional powers both for float and for complex types.
  636. for newtype in (float, complex):
  637. A = np.array(matrix_as_list, dtype=newtype)
  638. for p in (-0.7, -0.9, -2.4, -1.3):
  639. A_power = fractional_matrix_power(A, p)
  640. assert_(np.isnan(A_power).all())
  641. for p in (0.2, 1.43):
  642. A_power = fractional_matrix_power(A, p)
  643. A_round_trip = fractional_matrix_power(A_power, 1/p)
  644. assert_allclose(A_round_trip, A)
  645. def test_opposite_sign_complex_eigenvalues(self):
  646. M = [[2j, 4], [0, -2j]]
  647. R = [[1+1j, 2], [0, 1-1j]]
  648. assert_allclose(np.dot(R, R), M, atol=1e-14)
  649. assert_allclose(fractional_matrix_power(M, 0.5), R, atol=1e-14)
  650. class TestExpM:
  651. def test_zero(self):
  652. a = array([[0.,0],[0,0]])
  653. assert_array_almost_equal(expm(a),[[1,0],[0,1]])
  654. def test_single_elt(self):
  655. elt = expm(1)
  656. assert_allclose(elt, np.array([[np.e]]))
  657. @pytest.mark.parametrize('func', [expm, cosm, sinm, tanm, coshm, sinhm, tanhm])
  658. @pytest.mark.parametrize('dt',[int, float, np.float32, complex, np.complex64])
  659. @pytest.mark.parametrize('shape', [(0, 0), (1, 1)])
  660. def test_small_empty_matrix_input(self, func, dt, shape):
  661. # regression test for gh-11082 / gh-20372 - test behavior of expm
  662. # and related functions for small and zero-sized arrays.
  663. A = np.zeros(shape, dtype=dt)
  664. A0 = np.zeros((10, 10), dtype=dt)
  665. result = func(A)
  666. result0 = func(A0)
  667. assert result.shape == shape
  668. assert result.dtype == result0.dtype
  669. def test_2x2_input(self):
  670. E = np.e
  671. a = array([[1, 4], [1, 1]])
  672. aa = (E**4 + 1)/(2*E)
  673. bb = (E**4 - 1)/E
  674. assert_allclose(expm(a), array([[aa, bb], [bb/4, aa]]))
  675. assert expm(a.astype(np.complex64)).dtype.char == 'F'
  676. assert expm(a.astype(np.float32)).dtype.char == 'f'
  677. def test_nx2x2_input(self):
  678. E = np.e
  679. # These are integer matrices with integer eigenvalues
  680. a = np.array([[[1, 4], [1, 1]],
  681. [[1, 3], [1, -1]],
  682. [[1, 3], [4, 5]],
  683. [[1, 3], [5, 3]],
  684. [[4, 5], [-3, -4]]], order='F')
  685. # Exact results are computed symbolically
  686. a_res = np.array([
  687. [[(E**4+1)/(2*E), (E**4-1)/E],
  688. [(E**4-1)/4/E, (E**4+1)/(2*E)]],
  689. [[1/(4*E**2)+(3*E**2)/4, (3*E**2)/4-3/(4*E**2)],
  690. [E**2/4-1/(4*E**2), 3/(4*E**2)+E**2/4]],
  691. [[3/(4*E)+E**7/4, -3/(8*E)+(3*E**7)/8],
  692. [-1/(2*E)+E**7/2, 1/(4*E)+(3*E**7)/4]],
  693. [[5/(8*E**2)+(3*E**6)/8, -3/(8*E**2)+(3*E**6)/8],
  694. [-5/(8*E**2)+(5*E**6)/8, 3/(8*E**2)+(5*E**6)/8]],
  695. [[-3/(2*E)+(5*E)/2, -5/(2*E)+(5*E)/2],
  696. [3/(2*E)-(3*E)/2, 5/(2*E)-(3*E)/2]]
  697. ])
  698. assert_allclose(expm(a), a_res)
  699. def test_readonly(self):
  700. n = 7
  701. a = np.ones((n, n))
  702. a.flags.writeable = False
  703. expm(a)
  704. @pytest.mark.fail_slow(5)
  705. def test_gh18086(self):
  706. A = np.zeros((400, 400), dtype=float)
  707. rng = np.random.default_rng(100)
  708. i = rng.integers(0, 399, 500)
  709. j = rng.integers(0, 399, 500)
  710. A[i, j] = rng.random(500)
  711. # Problem appears when m = 9
  712. Am = np.empty((5, 400, 400), dtype=float)
  713. Am[0] = A.copy()
  714. m, s = pick_pade_structure(Am)
  715. assert m == 9
  716. # Check that result is accurate
  717. first_res = expm(A)
  718. np.testing.assert_array_almost_equal(logm(first_res), A)
  719. # Check that result is consistent
  720. for i in range(5):
  721. next_res = expm(A)
  722. np.testing.assert_array_almost_equal(first_res, next_res)
  723. class TestExpmFrechet:
  724. def test_expm_frechet(self):
  725. # a test of the basic functionality
  726. M = np.array([
  727. [1, 2, 3, 4],
  728. [5, 6, 7, 8],
  729. [0, 0, 1, 2],
  730. [0, 0, 5, 6],
  731. ], dtype=float)
  732. A = np.array([
  733. [1, 2],
  734. [5, 6],
  735. ], dtype=float)
  736. E = np.array([
  737. [3, 4],
  738. [7, 8],
  739. ], dtype=float)
  740. expected_expm = scipy.linalg.expm(A)
  741. expected_frechet = scipy.linalg.expm(M)[:2, 2:]
  742. for kwargs in ({}, {'method':'SPS'}, {'method':'blockEnlarge'}):
  743. observed_expm, observed_frechet = expm_frechet(A, E, **kwargs)
  744. assert_allclose(expected_expm, observed_expm)
  745. assert_allclose(expected_frechet, observed_frechet)
  746. def test_small_norm_expm_frechet(self):
  747. # methodically test matrices with a range of norms, for better coverage
  748. M_original = np.array([
  749. [1, 2, 3, 4],
  750. [5, 6, 7, 8],
  751. [0, 0, 1, 2],
  752. [0, 0, 5, 6],
  753. ], dtype=float)
  754. A_original = np.array([
  755. [1, 2],
  756. [5, 6],
  757. ], dtype=float)
  758. E_original = np.array([
  759. [3, 4],
  760. [7, 8],
  761. ], dtype=float)
  762. A_original_norm_1 = scipy.linalg.norm(A_original, 1)
  763. selected_m_list = [1, 3, 5, 7, 9, 11, 13, 15]
  764. m_neighbor_pairs = zip(selected_m_list[:-1], selected_m_list[1:])
  765. for ma, mb in m_neighbor_pairs:
  766. ell_a = scipy.linalg._expm_frechet.ell_table_61[ma]
  767. ell_b = scipy.linalg._expm_frechet.ell_table_61[mb]
  768. target_norm_1 = 0.5 * (ell_a + ell_b)
  769. scale = target_norm_1 / A_original_norm_1
  770. M = scale * M_original
  771. A = scale * A_original
  772. E = scale * E_original
  773. expected_expm = scipy.linalg.expm(A)
  774. expected_frechet = scipy.linalg.expm(M)[:2, 2:]
  775. observed_expm, observed_frechet = expm_frechet(A, E)
  776. assert_allclose(expected_expm, observed_expm)
  777. assert_allclose(expected_frechet, observed_frechet)
  778. def test_fuzz(self):
  779. rng = np.random.default_rng(1726500908359153)
  780. # try a bunch of crazy inputs
  781. rfuncs = (
  782. rng.uniform,
  783. rng.normal,
  784. rng.standard_cauchy,
  785. rng.exponential)
  786. ntests = 100
  787. for i in range(ntests):
  788. rfunc = rfuncs[rng.choice(4)]
  789. target_norm_1 = rng.exponential()
  790. n = rng.integers(2, 16)
  791. A_original = rfunc(size=(n,n))
  792. E_original = rfunc(size=(n,n))
  793. A_original_norm_1 = scipy.linalg.norm(A_original, 1)
  794. scale = target_norm_1 / A_original_norm_1
  795. A = scale * A_original
  796. E = scale * E_original
  797. M = np.vstack([
  798. np.hstack([A, E]),
  799. np.hstack([np.zeros_like(A), A])])
  800. expected_expm = scipy.linalg.expm(A)
  801. expected_frechet = scipy.linalg.expm(M)[:n, n:]
  802. observed_expm, observed_frechet = expm_frechet(A, E)
  803. assert_allclose(expected_expm, observed_expm, atol=5e-8)
  804. assert_allclose(expected_frechet, observed_frechet, atol=1e-7)
  805. def test_problematic_matrix(self):
  806. # this test case uncovered a bug which has since been fixed
  807. A = np.array([
  808. [1.50591997, 1.93537998],
  809. [0.41203263, 0.23443516],
  810. ], dtype=float)
  811. E = np.array([
  812. [1.87864034, 2.07055038],
  813. [1.34102727, 0.67341123],
  814. ], dtype=float)
  815. scipy.linalg.norm(A, 1)
  816. sps_expm, sps_frechet = expm_frechet(
  817. A, E, method='SPS')
  818. blockEnlarge_expm, blockEnlarge_frechet = expm_frechet(
  819. A, E, method='blockEnlarge')
  820. assert_allclose(sps_expm, blockEnlarge_expm)
  821. assert_allclose(sps_frechet, blockEnlarge_frechet)
  822. @pytest.mark.slow
  823. @pytest.mark.skip(reason='this test is deliberately slow')
  824. def test_medium_matrix(self):
  825. # profile this to see the speed difference
  826. n = 1000
  827. rng = np.random.default_rng(1234)
  828. A = rng.exponential(size=(n, n))
  829. E = rng.exponential(size=(n, n))
  830. sps_expm, sps_frechet = expm_frechet(
  831. A, E, method='SPS')
  832. blockEnlarge_expm, blockEnlarge_frechet = expm_frechet(
  833. A, E, method='blockEnlarge')
  834. assert_allclose(sps_expm, blockEnlarge_expm)
  835. assert_allclose(sps_frechet, blockEnlarge_frechet)
  836. def _help_expm_cond_search(A, A_norm, X, X_norm, eps, p):
  837. p = np.reshape(p, A.shape)
  838. p_norm = norm(p)
  839. perturbation = eps * p * (A_norm / p_norm)
  840. X_prime = expm(A + perturbation)
  841. scaled_relative_error = norm(X_prime - X) / (X_norm * eps)
  842. return -scaled_relative_error
  843. def _normalized_like(A, B):
  844. return A * (scipy.linalg.norm(B) / scipy.linalg.norm(A))
  845. def _relative_error(f, A, perturbation):
  846. X = f(A)
  847. X_prime = f(A + perturbation)
  848. return norm(X_prime - X) / norm(X)
  849. class TestExpmConditionNumber:
  850. def test_expm_cond_smoke(self):
  851. rng = np.random.default_rng(1234)
  852. for n in range(1, 4):
  853. A = rng.standard_normal((n, n))
  854. kappa = expm_cond(A)
  855. assert_array_less(0, kappa)
  856. def test_expm_bad_condition_number(self):
  857. A = np.array([
  858. [-1.128679820, 9.614183771e4, -4.524855739e9, 2.924969411e14],
  859. [0, -1.201010529, 9.634696872e4, -4.681048289e9],
  860. [0, 0, -1.132893222, 9.532491830e4],
  861. [0, 0, 0, -1.179475332],
  862. ])
  863. kappa = expm_cond(A)
  864. assert_array_less(1e36, kappa)
  865. def test_univariate(self):
  866. rng = np.random.default_rng(1234)
  867. for x in np.linspace(-5, 5, num=11):
  868. A = np.array([[x]])
  869. assert_allclose(expm_cond(A), abs(x))
  870. for x in np.logspace(-2, 2, num=11):
  871. A = np.array([[x]])
  872. assert_allclose(expm_cond(A), abs(x))
  873. for i in range(10):
  874. A = rng.standard_normal((1, 1))
  875. assert_allclose(expm_cond(A), np.absolute(A)[0, 0])
  876. @pytest.mark.slow
  877. def test_expm_cond_fuzz(self):
  878. rng = np.random.RandomState(12345)
  879. eps = 1e-5
  880. nsamples = 10
  881. for i in range(nsamples):
  882. n = rng.randint(2, 5)
  883. A = rng.randn(n, n)
  884. A_norm = scipy.linalg.norm(A)
  885. X = expm(A)
  886. X_norm = scipy.linalg.norm(X)
  887. kappa = expm_cond(A)
  888. # Look for the small perturbation that gives the greatest
  889. # relative error.
  890. f = functools.partial(_help_expm_cond_search,
  891. A, A_norm, X, X_norm, eps)
  892. guess = np.ones(n*n)
  893. out = minimize(f, guess, method='L-BFGS-B')
  894. xopt = out.x
  895. yopt = f(xopt)
  896. p_best = eps * _normalized_like(np.reshape(xopt, A.shape), A)
  897. p_best_relerr = _relative_error(expm, A, p_best)
  898. assert_allclose(p_best_relerr, -yopt * eps)
  899. # Check that the identified perturbation indeed gives greater
  900. # relative error than random perturbations with similar norms.
  901. for j in range(5):
  902. p_rand = eps * _normalized_like(rng.randn(*A.shape), A)
  903. assert_allclose(norm(p_best), norm(p_rand))
  904. p_rand_relerr = _relative_error(expm, A, p_rand)
  905. assert_array_less(p_rand_relerr, p_best_relerr)
  906. # The greatest relative error should not be much greater than
  907. # eps times the condition number kappa.
  908. # In the limit as eps approaches zero it should never be greater.
  909. assert_array_less(p_best_relerr, (1 + 2*eps) * eps * kappa)
  910. class TestKhatriRao:
  911. def test_basic(self):
  912. a = khatri_rao(array([[1, 2], [3, 4]]),
  913. array([[5, 6], [7, 8]]))
  914. assert_array_equal(a, array([[5, 12],
  915. [7, 16],
  916. [15, 24],
  917. [21, 32]]))
  918. b = khatri_rao(np.empty([2, 2]), np.empty([2, 2]))
  919. assert_array_equal(b.shape, (4, 2))
  920. def test_number_of_columns_equality(self):
  921. with pytest.raises(ValueError):
  922. a = array([[1, 2, 3],
  923. [4, 5, 6]])
  924. b = array([[1, 2],
  925. [3, 4]])
  926. khatri_rao(a, b)
  927. def test_to_assure_2d_array(self):
  928. with pytest.raises(ValueError):
  929. # both arrays are 1-D
  930. a = array([1, 2, 3])
  931. b = array([4, 5, 6])
  932. khatri_rao(a, b)
  933. with pytest.raises(ValueError):
  934. # first array is 1-D
  935. a = array([1, 2, 3])
  936. b = array([
  937. [1, 2, 3],
  938. [4, 5, 6]
  939. ])
  940. khatri_rao(a, b)
  941. with pytest.raises(ValueError):
  942. # second array is 1-D
  943. a = array([
  944. [1, 2, 3],
  945. [7, 8, 9]
  946. ])
  947. b = array([4, 5, 6])
  948. khatri_rao(a, b)
  949. def test_equality_of_two_equations(self):
  950. a = array([[1, 2], [3, 4]])
  951. b = array([[5, 6], [7, 8]])
  952. res1 = khatri_rao(a, b)
  953. res2 = np.vstack([np.kron(a[:, k], b[:, k])
  954. for k in range(b.shape[1])]).T
  955. assert_array_equal(res1, res2)
  956. def test_empty(self):
  957. a = np.empty((0, 2))
  958. b = np.empty((3, 2))
  959. res = khatri_rao(a, b)
  960. assert_allclose(res, np.empty((0, 2)))
  961. a = np.empty((3, 0))
  962. b = np.empty((5, 0))
  963. res = khatri_rao(a, b)
  964. assert_allclose(res, np.empty((15, 0)))
  965. @pytest.mark.parametrize('func',
  966. [logm, sqrtm, signm])
  967. def test_disp_dep(func):
  968. with pytest.deprecated_call():
  969. func(np.eye(2), disp=False)
  970. def test_blocksize_dep():
  971. with pytest.deprecated_call():
  972. sqrtm(np.eye(2), blocksize=10)