test_decomp.py 118 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189
  1. import itertools
  2. import platform
  3. import sys
  4. import warnings
  5. import numpy as np
  6. from numpy.testing import (assert_equal, assert_almost_equal,
  7. assert_array_almost_equal, assert_array_equal,
  8. assert_, assert_allclose)
  9. import pytest
  10. from pytest import raises as assert_raises
  11. from scipy.linalg import (eig, eigvals, lu, svd, svdvals, cholesky, qr,
  12. schur, rsf2csf, lu_solve, lu_factor, solve, diagsvd,
  13. hessenberg, rq, eig_banded, eigvals_banded, eigh,
  14. eigvalsh, qr_multiply, qz, orth, ordqz,
  15. subspace_angles, hadamard, eigvalsh_tridiagonal,
  16. eigh_tridiagonal, null_space, cdf2rdf, LinAlgError)
  17. from scipy.linalg.lapack import (dgbtrf, dgbtrs, zgbtrf, zgbtrs, dsbev,
  18. dsbevd, dsbevx, zhbevd, zhbevx)
  19. from scipy.linalg._misc import norm
  20. from scipy.linalg._decomp_qz import _select_function
  21. from scipy.stats import ortho_group
  22. from numpy import (array, diag, full, linalg, argsort, zeros, arange,
  23. float32, complex64, ravel, sqrt, iscomplex, shape, sort,
  24. sign, asarray, isfinite, ndarray, eye,)
  25. from scipy.linalg._testutils import assert_no_overwrite
  26. from scipy.sparse._sputils import matrix
  27. from scipy._lib._testutils import check_free_memory
  28. from scipy.linalg.blas import HAS_ILP64
  29. from scipy.conftest import skip_xp_invalid_arg
  30. from scipy.__config__ import CONFIG
  31. IS_WASM = (sys.platform == "emscripten" or platform.machine() in ["wasm32", "wasm64"])
  32. def _random_hermitian_matrix(n, posdef=False, dtype=float):
  33. "Generate random sym/hermitian array of the given size n"
  34. # FIXME non-deterministic rng
  35. if dtype in COMPLEX_DTYPES:
  36. A = np.random.rand(n, n) + np.random.rand(n, n)*1.0j
  37. A = (A + A.conj().T)/2
  38. else:
  39. A = np.random.rand(n, n)
  40. A = (A + A.T)/2
  41. if posdef:
  42. A += sqrt(2*n)*np.eye(n)
  43. return A.astype(dtype)
  44. REAL_DTYPES = [np.float32, np.float64]
  45. COMPLEX_DTYPES = [np.complex64, np.complex128]
  46. DTYPES = REAL_DTYPES + COMPLEX_DTYPES
  47. # XXX: This function should not be defined here, but somewhere in
  48. # scipy.linalg namespace
  49. def symrand(dim_or_eigv, rng):
  50. """Return a random symmetric (Hermitian) matrix.
  51. If 'dim_or_eigv' is an integer N, return a NxN matrix, with eigenvalues
  52. uniformly distributed on (-1,1).
  53. If 'dim_or_eigv' is 1-D real array 'a', return a matrix whose
  54. eigenvalues are 'a'.
  55. """
  56. if isinstance(dim_or_eigv, int):
  57. dim = dim_or_eigv
  58. d = rng.random(dim)*2 - 1
  59. elif (isinstance(dim_or_eigv, ndarray) and
  60. len(dim_or_eigv.shape) == 1):
  61. dim = dim_or_eigv.shape[0]
  62. d = dim_or_eigv
  63. else:
  64. raise TypeError("input type not supported.")
  65. v = ortho_group.rvs(dim)
  66. h = v.T.conj() @ diag(d) @ v
  67. # to avoid roundoff errors, symmetrize the matrix (again)
  68. h = 0.5*(h.T+h)
  69. return h
  70. class TestEigVals:
  71. def test_simple(self):
  72. a = [[1, 2, 3], [1, 2, 3], [2, 5, 6]]
  73. w = eigvals(a)
  74. exact_w = [(9+sqrt(93))/2, 0, (9-sqrt(93))/2]
  75. assert_array_almost_equal(w, exact_w)
  76. def test_simple_tr(self):
  77. a = array([[1, 2, 3], [1, 2, 3], [2, 5, 6]], 'd').T
  78. a = a.copy()
  79. a = a.T
  80. w = eigvals(a)
  81. exact_w = [(9+sqrt(93))/2, 0, (9-sqrt(93))/2]
  82. assert_array_almost_equal(w, exact_w)
  83. def test_simple_complex(self):
  84. a = [[1, 2, 3], [1, 2, 3], [2, 5, 6+1j]]
  85. w = eigvals(a)
  86. exact_w = [(9+1j+sqrt(92+6j))/2,
  87. 0,
  88. (9+1j-sqrt(92+6j))/2]
  89. assert_array_almost_equal(w, exact_w)
  90. def test_finite(self):
  91. a = [[1, 2, 3], [1, 2, 3], [2, 5, 6]]
  92. w = eigvals(a, check_finite=False)
  93. exact_w = [(9+sqrt(93))/2, 0, (9-sqrt(93))/2]
  94. assert_array_almost_equal(w, exact_w)
  95. @pytest.mark.parametrize('dt', [int, float, float32, complex, complex64])
  96. def test_empty(self, dt):
  97. a = np.empty((0, 0), dtype=dt)
  98. w = eigvals(a)
  99. assert w.shape == (0,)
  100. assert w.dtype == eigvals(np.eye(2, dtype=dt)).dtype
  101. w = eigvals(a, homogeneous_eigvals=True)
  102. assert w.shape == (2, 0)
  103. assert w.dtype == eigvals(np.eye(2, dtype=dt)).dtype
  104. class TestEig:
  105. def test_simple(self):
  106. a = array([[1, 2, 3], [1, 2, 3], [2, 5, 6]])
  107. w, v = eig(a)
  108. exact_w = [(9+sqrt(93))/2, 0, (9-sqrt(93))/2]
  109. v0 = array([1, 1, (1+sqrt(93)/3)/2])
  110. v1 = array([3., 0, -1])
  111. v2 = array([1, 1, (1-sqrt(93)/3)/2])
  112. v0 = v0 / norm(v0)
  113. v1 = v1 / norm(v1)
  114. v2 = v2 / norm(v2)
  115. assert_array_almost_equal(w, exact_w)
  116. assert_array_almost_equal(v0, v[:, 0]*sign(v[0, 0]))
  117. assert_array_almost_equal(v1, v[:, 1]*sign(v[0, 1]))
  118. assert_array_almost_equal(v2, v[:, 2]*sign(v[0, 2]))
  119. for i in range(3):
  120. assert_array_almost_equal(a @ v[:, i], w[i]*v[:, i])
  121. w, v = eig(a, left=1, right=0)
  122. for i in range(3):
  123. assert_array_almost_equal(a.T @ v[:, i], w[i]*v[:, i])
  124. def test_simple_complex_eig(self):
  125. a = array([[1, 2], [-2, 1]])
  126. w, vl, vr = eig(a, left=1, right=1)
  127. assert_array_almost_equal(w, array([1+2j, 1-2j]))
  128. for i in range(2):
  129. assert_array_almost_equal(a @ vr[:, i], w[i]*vr[:, i])
  130. for i in range(2):
  131. assert_array_almost_equal(a.conj().T @ vl[:, i],
  132. w[i].conj()*vl[:, i])
  133. def test_simple_complex(self):
  134. a = array([[1, 2, 3], [1, 2, 3], [2, 5, 6+1j]])
  135. w, vl, vr = eig(a, left=1, right=1)
  136. for i in range(3):
  137. assert_array_almost_equal(a @ vr[:, i], w[i]*vr[:, i])
  138. for i in range(3):
  139. assert_array_almost_equal(a.conj().T @ vl[:, i],
  140. w[i].conj()*vl[:, i])
  141. def test_gh_3054(self):
  142. a = [[1]]
  143. b = [[0]]
  144. w, vr = eig(a, b, homogeneous_eigvals=True)
  145. assert_allclose(w[1, 0], 0)
  146. assert_(w[0, 0] != 0)
  147. assert_allclose(vr, 1)
  148. w, vr = eig(a, b)
  149. assert_equal(w, np.inf)
  150. assert_allclose(vr, 1)
  151. def _check_gen_eig(self, A, B, atol_homog=1e-13, rtol_homog=1e-13,
  152. atol=1e-13, rtol=1e-13):
  153. if B is not None:
  154. A, B = asarray(A), asarray(B)
  155. B0 = B
  156. else:
  157. A = asarray(A)
  158. B0 = B
  159. B = np.eye(*A.shape)
  160. msg = f"\n{A!r}\n{B!r}"
  161. # Eigenvalues in homogeneous coordinates
  162. w, vr = eig(A, B0, homogeneous_eigvals=True)
  163. wt = eigvals(A, B0, homogeneous_eigvals=True)
  164. val1 = A @ vr * w[1, :]
  165. val2 = B @ vr * w[0, :]
  166. for i in range(val1.shape[1]):
  167. assert_allclose(val1[:, i], val2[:, i],
  168. rtol=rtol_homog, atol=atol_homog, err_msg=msg)
  169. if B0 is None:
  170. assert_allclose(w[1, :], 1)
  171. assert_allclose(wt[1, :], 1)
  172. perm = np.lexsort(w)
  173. permt = np.lexsort(wt)
  174. assert_allclose(w[:, perm], wt[:, permt], atol=1e-7, rtol=1e-7,
  175. err_msg=msg)
  176. length = np.empty(len(vr))
  177. for i in range(len(vr)):
  178. length[i] = norm(vr[:, i])
  179. assert_allclose(length, np.ones(length.size), err_msg=msg,
  180. atol=1e-7, rtol=1e-7)
  181. # Convert homogeneous coordinates
  182. beta_nonzero = (w[1, :] != 0)
  183. wh = w[0, beta_nonzero] / w[1, beta_nonzero]
  184. # Eigenvalues in standard coordinates
  185. w, vr = eig(A, B0)
  186. wt = eigvals(A, B0)
  187. val1 = A @ vr
  188. val2 = B @ vr * w
  189. res = val1 - val2
  190. for i in range(res.shape[1]):
  191. if np.all(isfinite(res[:, i])):
  192. assert_allclose(res[:, i], 0,
  193. rtol=rtol, atol=atol, err_msg=msg)
  194. # try to consistently order eigenvalues, including complex conjugate pairs
  195. w_fin = w[isfinite(w)]
  196. wt_fin = wt[isfinite(wt)]
  197. # prune noise in the real parts
  198. w_fin = -1j * np.real_if_close(1j*w_fin, tol=1e-10)
  199. wt_fin = -1j * np.real_if_close(1j*wt_fin, tol=1e-10)
  200. perm = argsort(abs(w_fin) + w_fin.imag)
  201. permt = argsort(abs(wt_fin) + wt_fin.imag)
  202. assert_allclose(w_fin[perm], wt_fin[permt],
  203. atol=1e-7, rtol=1e-7, err_msg=msg)
  204. length = np.empty(len(vr))
  205. for i in range(len(vr)):
  206. length[i] = norm(vr[:, i])
  207. assert_allclose(length, np.ones(length.size), err_msg=msg)
  208. # Compare homogeneous and nonhomogeneous versions
  209. assert_allclose(sort(wh), sort(w[np.isfinite(w)]))
  210. def test_singular(self):
  211. # Example taken from
  212. # https://web.archive.org/web/20040903121217/http://www.cs.umu.se/research/nla/singular_pairs/guptri/matlab.html
  213. A = array([[22, 34, 31, 31, 17],
  214. [45, 45, 42, 19, 29],
  215. [39, 47, 49, 26, 34],
  216. [27, 31, 26, 21, 15],
  217. [38, 44, 44, 24, 30]])
  218. B = array([[13, 26, 25, 17, 24],
  219. [31, 46, 40, 26, 37],
  220. [26, 40, 19, 25, 25],
  221. [16, 25, 27, 14, 23],
  222. [24, 35, 18, 21, 22]])
  223. with np.errstate(all='ignore'):
  224. self._check_gen_eig(A, B, atol_homog=5e-13, atol=5e-13)
  225. def test_falker(self):
  226. # Test matrices giving some Nan generalized eigenvalues.
  227. M = diag(array([1, 0, 3]))
  228. K = array(([2, -1, -1], [-1, 2, -1], [-1, -1, 2]))
  229. D = array(([1, -1, 0], [-1, 1, 0], [0, 0, 0]))
  230. Z = zeros((3, 3))
  231. I3 = eye(3)
  232. A = np.block([[I3, Z], [Z, -K]])
  233. B = np.block([[Z, I3], [M, D]])
  234. with np.errstate(all='ignore'):
  235. self._check_gen_eig(A, B)
  236. def test_bad_geneig(self):
  237. # Ticket #709 (strange return values from DGGEV)
  238. def matrices(omega):
  239. c1 = -9 + omega**2
  240. c2 = 2*omega
  241. A = [[1, 0, 0, 0],
  242. [0, 1, 0, 0],
  243. [0, 0, c1, 0],
  244. [0, 0, 0, c1]]
  245. B = [[0, 0, 1, 0],
  246. [0, 0, 0, 1],
  247. [1, 0, 0, -c2],
  248. [0, 1, c2, 0]]
  249. return A, B
  250. # With a buggy LAPACK, this can fail for different omega on different
  251. # machines -- so we need to test several values
  252. with np.errstate(all='ignore'):
  253. for k in range(100):
  254. A, B = matrices(omega=k*5./100)
  255. self._check_gen_eig(A, B)
  256. def test_make_eigvals(self):
  257. # Step through all paths in _make_eigvals
  258. # Real eigenvalues
  259. rng = np.random.RandomState(1234)
  260. A = symrand(3, rng)
  261. self._check_gen_eig(A, None)
  262. B = symrand(3, rng)
  263. self._check_gen_eig(A, B)
  264. # Complex eigenvalues
  265. A = rng.random((3, 3)) + 1j*rng.random((3, 3))
  266. self._check_gen_eig(A, None)
  267. B = rng.random((3, 3)) + 1j*rng.random((3, 3))
  268. self._check_gen_eig(A, B)
  269. def test_check_finite(self):
  270. a = [[1, 2, 3], [1, 2, 3], [2, 5, 6]]
  271. w, v = eig(a, check_finite=False)
  272. exact_w = [(9+sqrt(93))/2, 0, (9-sqrt(93))/2]
  273. v0 = array([1, 1, (1+sqrt(93)/3)/2])
  274. v1 = array([3., 0, -1])
  275. v2 = array([1, 1, (1-sqrt(93)/3)/2])
  276. v0 = v0 / norm(v0)
  277. v1 = v1 / norm(v1)
  278. v2 = v2 / norm(v2)
  279. assert_array_almost_equal(w, exact_w)
  280. assert_array_almost_equal(v0, v[:, 0]*sign(v[0, 0]))
  281. assert_array_almost_equal(v1, v[:, 1]*sign(v[0, 1]))
  282. assert_array_almost_equal(v2, v[:, 2]*sign(v[0, 2]))
  283. for i in range(3):
  284. assert_array_almost_equal(a @ v[:, i], w[i]*v[:, i])
  285. def test_not_square_error(self):
  286. """Check that passing a non-square array raises a ValueError."""
  287. A = np.arange(6).reshape(3, 2)
  288. assert_raises(ValueError, eig, A)
  289. def test_shape_mismatch(self):
  290. """Check that passing arrays of with different shapes
  291. raises a ValueError."""
  292. A = eye(2)
  293. B = np.arange(9.0).reshape(3, 3)
  294. assert_raises(ValueError, eig, A, B)
  295. assert_raises(ValueError, eig, B, A)
  296. def test_gh_11577(self):
  297. # https://github.com/scipy/scipy/issues/11577
  298. # `A - lambda B` should have 4 and 8 among the eigenvalues, and this
  299. # was apparently broken on some platforms
  300. A = np.array([[12.0, 28.0, 76.0, 220.0],
  301. [16.0, 32.0, 80.0, 224.0],
  302. [24.0, 40.0, 88.0, 232.0],
  303. [40.0, 56.0, 104.0, 248.0]], dtype='float64')
  304. B = np.array([[2.0, 4.0, 10.0, 28.0],
  305. [3.0, 5.0, 11.0, 29.0],
  306. [5.0, 7.0, 13.0, 31.0],
  307. [9.0, 11.0, 17.0, 35.0]], dtype='float64')
  308. D, V = eig(A, B)
  309. # The problem is ill-conditioned, and two other eigenvalues
  310. # depend on ATLAS/OpenBLAS version, compiler version etc
  311. # see gh-11577 for discussion
  312. #
  313. # NB: it is tempting to use `assert_allclose(D[:2], [4, 8])` instead but
  314. # the ordering of eigenvalues also comes out different on different
  315. # systems depending on who knows what.
  316. with warnings.catch_warnings():
  317. # isclose chokes on inf/nan values
  318. warnings.filterwarnings(
  319. "ignore", "invalid value encountered in multiply", RuntimeWarning)
  320. assert np.isclose(D, 4.0, atol=1e-14).any()
  321. assert np.isclose(D, 8.0, atol=1e-14).any()
  322. @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  323. def test_empty(self, dt):
  324. a = np.empty((0, 0), dtype=dt)
  325. w, vr = eig(a)
  326. w_n, vr_n = eig(np.eye(2, dtype=dt))
  327. assert w.shape == (0,)
  328. assert w.dtype == w_n.dtype #eigvals(np.eye(2, dtype=dt)).dtype
  329. assert_allclose(vr, np.empty((0, 0)))
  330. assert vr.shape == (0, 0)
  331. assert vr.dtype == vr_n.dtype
  332. w, vr = eig(a, homogeneous_eigvals=True)
  333. assert w.shape == (2, 0)
  334. assert w.dtype == w_n.dtype
  335. assert vr.shape == (0, 0)
  336. assert vr.dtype == vr_n.dtype
  337. @pytest.mark.parametrize("include_B", [False, True])
  338. @pytest.mark.parametrize("left", [False, True])
  339. @pytest.mark.parametrize("right", [False, True])
  340. @pytest.mark.parametrize("homogeneous_eigvals", [False, True])
  341. @pytest.mark.parametrize("dtype", [np.float32, np.complex128])
  342. def test_nd_input(self, include_B, left, right, homogeneous_eigvals, dtype):
  343. batch_shape = (3, 2)
  344. core_shape = (4, 4)
  345. rng = np.random.default_rng(3249823598235)
  346. A = rng.random(batch_shape + core_shape).astype(dtype)
  347. B = rng.random(batch_shape + core_shape).astype(dtype)
  348. kwargs = dict(right=right, homogeneous_eigvals=homogeneous_eigvals)
  349. if include_B:
  350. res = eig(A, b=B, left=left, **kwargs)
  351. else:
  352. res = eig(A, left=left, **kwargs)
  353. for i in range(batch_shape[0]):
  354. for j in range(batch_shape[1]):
  355. if include_B:
  356. ref = eig(A[i, j], b=B[i, j], left=left, **kwargs)
  357. else:
  358. ref = eig(A[i, j], left=left, **kwargs)
  359. if left or right:
  360. for k in range(len(ref)):
  361. assert_allclose(res[k][i, j], ref[k])
  362. else:
  363. assert_allclose(res[i, j], ref)
  364. class TestEigBanded:
  365. def setup_method(self):
  366. self.create_bandmat()
  367. def create_bandmat(self):
  368. """Create the full matrix `self.fullmat` and
  369. the corresponding band matrix `self.bandmat`."""
  370. N = 10
  371. self.KL = 2 # number of subdiagonals (below the diagonal)
  372. self.KU = 2 # number of superdiagonals (above the diagonal)
  373. # symmetric band matrix
  374. self.sym_mat = (diag(full(N, 1.0))
  375. + diag(full(N-1, -1.0), -1) + diag(full(N-1, -1.0), 1)
  376. + diag(full(N-2, -2.0), -2) + diag(full(N-2, -2.0), 2))
  377. # hermitian band matrix
  378. self.herm_mat = (diag(full(N, -1.0))
  379. + 1j*diag(full(N-1, 1.0), -1)
  380. - 1j*diag(full(N-1, 1.0), 1)
  381. + diag(full(N-2, -2.0), -2)
  382. + diag(full(N-2, -2.0), 2))
  383. # general real band matrix
  384. self.real_mat = (diag(full(N, 1.0))
  385. + diag(full(N-1, -1.0), -1) + diag(full(N-1, -3.0), 1)
  386. + diag(full(N-2, 2.0), -2) + diag(full(N-2, -2.0), 2))
  387. # general complex band matrix
  388. self.comp_mat = (1j*diag(full(N, 1.0))
  389. + diag(full(N-1, -1.0), -1)
  390. + 1j*diag(full(N-1, -3.0), 1)
  391. + diag(full(N-2, 2.0), -2)
  392. + diag(full(N-2, -2.0), 2))
  393. # Eigenvalues and -vectors from linalg.eig
  394. ew, ev = linalg.eig(self.sym_mat)
  395. ew = ew.real
  396. args = argsort(ew)
  397. self.w_sym_lin = ew[args]
  398. self.evec_sym_lin = ev[:, args]
  399. ew, ev = linalg.eig(self.herm_mat)
  400. ew = ew.real
  401. args = argsort(ew)
  402. self.w_herm_lin = ew[args]
  403. self.evec_herm_lin = ev[:, args]
  404. # Extract upper bands from symmetric and hermitian band matrices
  405. # (for use in dsbevd, dsbevx, zhbevd, zhbevx
  406. # and their single precision versions)
  407. LDAB = self.KU + 1
  408. self.bandmat_sym = zeros((LDAB, N), dtype=float)
  409. self.bandmat_herm = zeros((LDAB, N), dtype=complex)
  410. for i in range(LDAB):
  411. self.bandmat_sym[LDAB-i-1, i:N] = diag(self.sym_mat, i)
  412. self.bandmat_herm[LDAB-i-1, i:N] = diag(self.herm_mat, i)
  413. # Extract bands from general real and complex band matrix
  414. # (for use in dgbtrf, dgbtrs and their single precision versions)
  415. LDAB = 2*self.KL + self.KU + 1
  416. self.bandmat_real = zeros((LDAB, N), dtype=float)
  417. self.bandmat_real[2*self.KL, :] = diag(self.real_mat) # diagonal
  418. for i in range(self.KL):
  419. # superdiagonals
  420. self.bandmat_real[2*self.KL-1-i, i+1:N] = diag(self.real_mat, i+1)
  421. # subdiagonals
  422. self.bandmat_real[2*self.KL+1+i, 0:N-1-i] = diag(self.real_mat,
  423. -i-1)
  424. self.bandmat_comp = zeros((LDAB, N), dtype=complex)
  425. self.bandmat_comp[2*self.KL, :] = diag(self.comp_mat) # diagonal
  426. for i in range(self.KL):
  427. # superdiagonals
  428. self.bandmat_comp[2*self.KL-1-i, i+1:N] = diag(self.comp_mat, i+1)
  429. # subdiagonals
  430. self.bandmat_comp[2*self.KL+1+i, 0:N-1-i] = diag(self.comp_mat,
  431. -i-1)
  432. # absolute value for linear equation system A*x = b
  433. self.b = 1.0*arange(N)
  434. self.bc = self.b * (1 + 1j)
  435. #####################################################################
  436. def test_dsbev(self):
  437. """Compare dsbev eigenvalues and eigenvectors with
  438. the result of linalg.eig."""
  439. w, evec, info = dsbev(self.bandmat_sym, compute_v=1)
  440. evec_ = evec[:, argsort(w)]
  441. assert_array_almost_equal(sort(w), self.w_sym_lin)
  442. assert_array_almost_equal(abs(evec_), abs(self.evec_sym_lin))
  443. def test_dsbevd(self):
  444. """Compare dsbevd eigenvalues and eigenvectors with
  445. the result of linalg.eig."""
  446. w, evec, info = dsbevd(self.bandmat_sym, compute_v=1)
  447. evec_ = evec[:, argsort(w)]
  448. assert_array_almost_equal(sort(w), self.w_sym_lin)
  449. assert_array_almost_equal(abs(evec_), abs(self.evec_sym_lin))
  450. def test_dsbevx(self):
  451. """Compare dsbevx eigenvalues and eigenvectors
  452. with the result of linalg.eig."""
  453. N, N = shape(self.sym_mat)
  454. # Achtung: Argumente 0.0,0.0,range?
  455. w, evec, num, ifail, info = dsbevx(self.bandmat_sym, 0.0, 0.0, 1, N,
  456. compute_v=1, range=2)
  457. evec_ = evec[:, argsort(w)]
  458. assert_array_almost_equal(sort(w), self.w_sym_lin)
  459. assert_array_almost_equal(abs(evec_), abs(self.evec_sym_lin))
  460. def test_zhbevd(self):
  461. """Compare zhbevd eigenvalues and eigenvectors
  462. with the result of linalg.eig."""
  463. w, evec, info = zhbevd(self.bandmat_herm, compute_v=1)
  464. evec_ = evec[:, argsort(w)]
  465. assert_array_almost_equal(sort(w), self.w_herm_lin)
  466. assert_array_almost_equal(abs(evec_), abs(self.evec_herm_lin))
  467. def test_zhbevx(self):
  468. """Compare zhbevx eigenvalues and eigenvectors
  469. with the result of linalg.eig."""
  470. N, N = shape(self.herm_mat)
  471. # Achtung: Argumente 0.0,0.0,range?
  472. w, evec, num, ifail, info = zhbevx(self.bandmat_herm, 0.0, 0.0, 1, N,
  473. compute_v=1, range=2)
  474. evec_ = evec[:, argsort(w)]
  475. assert_array_almost_equal(sort(w), self.w_herm_lin)
  476. assert_array_almost_equal(abs(evec_), abs(self.evec_herm_lin))
  477. def test_eigvals_banded(self):
  478. """Compare eigenvalues of eigvals_banded with those of linalg.eig."""
  479. w_sym = eigvals_banded(self.bandmat_sym)
  480. w_sym = w_sym.real
  481. assert_array_almost_equal(sort(w_sym), self.w_sym_lin)
  482. w_herm = eigvals_banded(self.bandmat_herm)
  483. w_herm = w_herm.real
  484. assert_array_almost_equal(sort(w_herm), self.w_herm_lin)
  485. # extracting eigenvalues with respect to an index range
  486. ind1 = 2
  487. ind2 = np.longlong(6)
  488. w_sym_ind = eigvals_banded(self.bandmat_sym,
  489. select='i', select_range=(ind1, ind2))
  490. assert_array_almost_equal(sort(w_sym_ind),
  491. self.w_sym_lin[ind1:ind2+1])
  492. w_herm_ind = eigvals_banded(self.bandmat_herm,
  493. select='i', select_range=(ind1, ind2))
  494. assert_array_almost_equal(sort(w_herm_ind),
  495. self.w_herm_lin[ind1:ind2+1])
  496. # extracting eigenvalues with respect to a value range
  497. v_lower = self.w_sym_lin[ind1] - 1.0e-5
  498. v_upper = self.w_sym_lin[ind2] + 1.0e-5
  499. w_sym_val = eigvals_banded(self.bandmat_sym,
  500. select='v', select_range=(v_lower, v_upper))
  501. assert_array_almost_equal(sort(w_sym_val),
  502. self.w_sym_lin[ind1:ind2+1])
  503. v_lower = self.w_herm_lin[ind1] - 1.0e-5
  504. v_upper = self.w_herm_lin[ind2] + 1.0e-5
  505. w_herm_val = eigvals_banded(self.bandmat_herm,
  506. select='v',
  507. select_range=(v_lower, v_upper))
  508. assert_array_almost_equal(sort(w_herm_val),
  509. self.w_herm_lin[ind1:ind2+1])
  510. w_sym = eigvals_banded(self.bandmat_sym, check_finite=False)
  511. w_sym = w_sym.real
  512. assert_array_almost_equal(sort(w_sym), self.w_sym_lin)
  513. def test_eig_banded(self):
  514. """Compare eigenvalues and eigenvectors of eig_banded
  515. with those of linalg.eig. """
  516. w_sym, evec_sym = eig_banded(self.bandmat_sym)
  517. evec_sym_ = evec_sym[:, argsort(w_sym.real)]
  518. assert_array_almost_equal(sort(w_sym), self.w_sym_lin)
  519. assert_array_almost_equal(abs(evec_sym_), abs(self.evec_sym_lin))
  520. w_herm, evec_herm = eig_banded(self.bandmat_herm)
  521. evec_herm_ = evec_herm[:, argsort(w_herm.real)]
  522. assert_array_almost_equal(sort(w_herm), self.w_herm_lin)
  523. assert_array_almost_equal(abs(evec_herm_), abs(self.evec_herm_lin))
  524. # extracting eigenvalues with respect to an index range
  525. ind1 = 2
  526. ind2 = 6
  527. w_sym_ind, evec_sym_ind = eig_banded(self.bandmat_sym,
  528. select='i',
  529. select_range=(ind1, ind2))
  530. assert_array_almost_equal(sort(w_sym_ind),
  531. self.w_sym_lin[ind1:ind2+1])
  532. assert_array_almost_equal(abs(evec_sym_ind),
  533. abs(self.evec_sym_lin[:, ind1:ind2+1]))
  534. w_herm_ind, evec_herm_ind = eig_banded(self.bandmat_herm,
  535. select='i',
  536. select_range=(ind1, ind2))
  537. assert_array_almost_equal(sort(w_herm_ind),
  538. self.w_herm_lin[ind1:ind2+1])
  539. assert_array_almost_equal(abs(evec_herm_ind),
  540. abs(self.evec_herm_lin[:, ind1:ind2+1]))
  541. # extracting eigenvalues with respect to a value range
  542. v_lower = self.w_sym_lin[ind1] - 1.0e-5
  543. v_upper = self.w_sym_lin[ind2] + 1.0e-5
  544. w_sym_val, evec_sym_val = eig_banded(self.bandmat_sym,
  545. select='v',
  546. select_range=(v_lower, v_upper))
  547. assert_array_almost_equal(sort(w_sym_val),
  548. self.w_sym_lin[ind1:ind2+1])
  549. assert_array_almost_equal(abs(evec_sym_val),
  550. abs(self.evec_sym_lin[:, ind1:ind2+1]))
  551. v_lower = self.w_herm_lin[ind1] - 1.0e-5
  552. v_upper = self.w_herm_lin[ind2] + 1.0e-5
  553. w_herm_val, evec_herm_val = eig_banded(self.bandmat_herm,
  554. select='v',
  555. select_range=(v_lower, v_upper))
  556. assert_array_almost_equal(sort(w_herm_val),
  557. self.w_herm_lin[ind1:ind2+1])
  558. assert_array_almost_equal(abs(evec_herm_val),
  559. abs(self.evec_herm_lin[:, ind1:ind2+1]))
  560. w_sym, evec_sym = eig_banded(self.bandmat_sym, check_finite=False)
  561. evec_sym_ = evec_sym[:, argsort(w_sym.real)]
  562. assert_array_almost_equal(sort(w_sym), self.w_sym_lin)
  563. assert_array_almost_equal(abs(evec_sym_), abs(self.evec_sym_lin))
  564. def test_dgbtrf(self):
  565. """Compare dgbtrf LU factorisation with the LU factorisation result
  566. of linalg.lu."""
  567. M, N = shape(self.real_mat)
  568. lu_symm_band, ipiv, info = dgbtrf(self.bandmat_real, self.KL, self.KU)
  569. # extract matrix u from lu_symm_band
  570. u = diag(lu_symm_band[2*self.KL, :])
  571. for i in range(self.KL + self.KU):
  572. u += diag(lu_symm_band[2*self.KL-1-i, i+1:N], i+1)
  573. p_lin, l_lin, u_lin = lu(self.real_mat, permute_l=0)
  574. assert_array_almost_equal(u, u_lin)
  575. def test_zgbtrf(self):
  576. """Compare zgbtrf LU factorisation with the LU factorisation result
  577. of linalg.lu."""
  578. M, N = shape(self.comp_mat)
  579. lu_symm_band, ipiv, info = zgbtrf(self.bandmat_comp, self.KL, self.KU)
  580. # extract matrix u from lu_symm_band
  581. u = diag(lu_symm_band[2*self.KL, :])
  582. for i in range(self.KL + self.KU):
  583. u += diag(lu_symm_band[2*self.KL-1-i, i+1:N], i+1)
  584. p_lin, l_lin, u_lin = lu(self.comp_mat, permute_l=0)
  585. assert_array_almost_equal(u, u_lin)
  586. def test_dgbtrs(self):
  587. """Compare dgbtrs solutions for linear equation system A*x = b
  588. with solutions of linalg.solve."""
  589. lu_symm_band, ipiv, info = dgbtrf(self.bandmat_real, self.KL, self.KU)
  590. y, info = dgbtrs(lu_symm_band, self.KL, self.KU, self.b, ipiv)
  591. y_lin = linalg.solve(self.real_mat, self.b)
  592. assert_array_almost_equal(y, y_lin)
  593. def test_zgbtrs(self):
  594. """Compare zgbtrs solutions for linear equation system A*x = b
  595. with solutions of linalg.solve."""
  596. lu_symm_band, ipiv, info = zgbtrf(self.bandmat_comp, self.KL, self.KU)
  597. y, info = zgbtrs(lu_symm_band, self.KL, self.KU, self.bc, ipiv)
  598. y_lin = linalg.solve(self.comp_mat, self.bc)
  599. assert_array_almost_equal(y, y_lin)
  600. @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  601. def test_empty(self, dt):
  602. a_band = np.empty((0, 0), dtype=dt)
  603. w, v = eig_banded(a_band)
  604. w_n, v_n = eig_banded(np.array([[0, 0], [1, 1]], dtype=dt))
  605. assert w.shape == (0,)
  606. assert w.dtype == w_n.dtype
  607. assert v.shape == (0, 0)
  608. assert v.dtype == v_n.dtype
  609. w = eig_banded(a_band, eigvals_only=True)
  610. assert w.shape == (0,)
  611. assert w.dtype == w_n.dtype
  612. class TestEigTridiagonal:
  613. def setup_method(self):
  614. self.create_trimat()
  615. def create_trimat(self):
  616. """Create the full matrix `self.fullmat`, `self.d`, and `self.e`."""
  617. N = 10
  618. # symmetric band matrix
  619. self.d = full(N, 1.0)
  620. self.e = full(N-1, -1.0)
  621. self.full_mat = (diag(self.d) + diag(self.e, -1) + diag(self.e, 1))
  622. ew, ev = linalg.eig(self.full_mat)
  623. ew = ew.real
  624. args = argsort(ew)
  625. self.w = ew[args]
  626. self.evec = ev[:, args]
  627. def test_degenerate(self):
  628. """Test error conditions."""
  629. # Wrong sizes
  630. assert_raises(ValueError, eigvalsh_tridiagonal, self.d, self.e[:-1])
  631. # Must be real
  632. assert_raises(TypeError, eigvalsh_tridiagonal, self.d, self.e * 1j)
  633. # Bad driver
  634. assert_raises(TypeError, eigvalsh_tridiagonal, self.d, self.e,
  635. lapack_driver=1.)
  636. assert_raises(ValueError, eigvalsh_tridiagonal, self.d, self.e,
  637. lapack_driver='foo')
  638. # Bad bounds
  639. assert_raises(ValueError, eigvalsh_tridiagonal, self.d, self.e,
  640. select='i', select_range=(0, -1))
  641. def test_eigvalsh_tridiagonal(self):
  642. """Compare eigenvalues of eigvalsh_tridiagonal with those of eig."""
  643. # can't use ?STERF with subselection
  644. for driver in ('sterf', 'stev', 'stevd', 'stebz', 'stemr', 'auto'):
  645. w = eigvalsh_tridiagonal(self.d, self.e, lapack_driver=driver)
  646. assert_array_almost_equal(sort(w), self.w)
  647. for driver in ('sterf', 'stev', 'stevd'):
  648. assert_raises(ValueError, eigvalsh_tridiagonal, self.d, self.e,
  649. lapack_driver=driver, select='i',
  650. select_range=(0, 1))
  651. for driver in ('stebz', 'stemr', 'auto'):
  652. # extracting eigenvalues with respect to the full index range
  653. w_ind = eigvalsh_tridiagonal(
  654. self.d, self.e, select='i', select_range=(0, len(self.d)-1),
  655. lapack_driver=driver)
  656. assert_array_almost_equal(sort(w_ind), self.w)
  657. # extracting eigenvalues with respect to an index range
  658. ind1 = 2
  659. ind2 = 6
  660. w_ind = eigvalsh_tridiagonal(
  661. self.d, self.e, select='i', select_range=(ind1, ind2),
  662. lapack_driver=driver)
  663. assert_array_almost_equal(sort(w_ind), self.w[ind1:ind2+1])
  664. # extracting eigenvalues with respect to a value range
  665. v_lower = self.w[ind1] - 1.0e-5
  666. v_upper = self.w[ind2] + 1.0e-5
  667. w_val = eigvalsh_tridiagonal(
  668. self.d, self.e, select='v', select_range=(v_lower, v_upper),
  669. lapack_driver=driver)
  670. assert_array_almost_equal(sort(w_val), self.w[ind1:ind2+1])
  671. def test_eigh_tridiagonal(self):
  672. """Compare eigenvalues and eigenvectors of eigh_tridiagonal
  673. with those of eig. """
  674. # can't use ?STERF when eigenvectors are requested
  675. assert_raises(ValueError, eigh_tridiagonal, self.d, self.e,
  676. lapack_driver='sterf')
  677. for driver in ('stebz', 'stev', 'stevd', 'stemr', 'auto'):
  678. w, evec = eigh_tridiagonal(self.d, self.e, lapack_driver=driver)
  679. evec_ = evec[:, argsort(w)]
  680. assert_array_almost_equal(sort(w), self.w)
  681. assert_array_almost_equal(abs(evec_), abs(self.evec))
  682. assert_raises(ValueError, eigh_tridiagonal, self.d, self.e,
  683. lapack_driver='stev', select='i', select_range=(0, 1))
  684. for driver in ('stebz', 'stemr', 'auto'):
  685. # extracting eigenvalues with respect to an index range
  686. ind1 = 0
  687. ind2 = len(self.d)-1
  688. w, evec = eigh_tridiagonal(
  689. self.d, self.e, select='i', select_range=(ind1, ind2),
  690. lapack_driver=driver)
  691. assert_array_almost_equal(sort(w), self.w)
  692. assert_array_almost_equal(abs(evec), abs(self.evec))
  693. ind1 = 2
  694. ind2 = 6
  695. w, evec = eigh_tridiagonal(
  696. self.d, self.e, select='i', select_range=(ind1, ind2),
  697. lapack_driver=driver)
  698. assert_array_almost_equal(sort(w), self.w[ind1:ind2+1])
  699. assert_array_almost_equal(abs(evec),
  700. abs(self.evec[:, ind1:ind2+1]))
  701. # extracting eigenvalues with respect to a value range
  702. v_lower = self.w[ind1] - 1.0e-5
  703. v_upper = self.w[ind2] + 1.0e-5
  704. w, evec = eigh_tridiagonal(
  705. self.d, self.e, select='v', select_range=(v_lower, v_upper),
  706. lapack_driver=driver)
  707. assert_array_almost_equal(sort(w), self.w[ind1:ind2+1])
  708. assert_array_almost_equal(abs(evec),
  709. abs(self.evec[:, ind1:ind2+1]))
  710. def test_eigh_tridiagonal_1x1(self):
  711. """See gh-20075"""
  712. a = np.array([-2.0])
  713. b = np.array([])
  714. x = eigh_tridiagonal(a, b, eigvals_only=True)
  715. assert x.ndim == 1
  716. assert_allclose(x, a)
  717. x, V = eigh_tridiagonal(a, b, select="i", select_range=(0, 0))
  718. assert x.ndim == 1
  719. assert V.ndim == 2
  720. assert_allclose(x, a)
  721. assert_allclose(V, array([[1.]]))
  722. x, V = eigh_tridiagonal(a, b, select="v", select_range=(-2, 0))
  723. assert x.size == 0
  724. assert x.shape == (0,)
  725. assert V.shape == (1, 0)
  726. class TestEigh:
  727. def test_wrong_inputs(self):
  728. # Nonsquare a
  729. assert_raises(ValueError, eigh, np.ones([1, 2]))
  730. # Nonsquare b
  731. assert_raises(ValueError, eigh, np.ones([2, 2]), np.ones([2, 1]))
  732. # Incompatible a, b sizes
  733. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([2, 2]))
  734. # Wrong type parameter for generalized problem
  735. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  736. type=4)
  737. # Both value and index subsets requested
  738. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  739. subset_by_value=[1, 2], subset_by_index=[2, 4])
  740. # Invalid upper index spec
  741. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  742. subset_by_index=[0, 4])
  743. # Invalid lower index
  744. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  745. subset_by_index=[-2, 2])
  746. # Invalid index spec #2
  747. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  748. subset_by_index=[2, 0])
  749. # Invalid value spec
  750. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  751. subset_by_value=[2, 0])
  752. # Invalid driver name
  753. assert_raises(ValueError, eigh, np.ones([2, 2]), driver='wrong')
  754. # Generalized driver selection without b
  755. assert_raises(ValueError, eigh, np.ones([3, 3]), None, driver='gvx')
  756. # Standard driver with b
  757. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  758. driver='evr')
  759. # Subset request from invalid driver
  760. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  761. driver='gvd', subset_by_index=[1, 2])
  762. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  763. driver='gvd', subset_by_index=[1, 2])
  764. def test_nonpositive_b(self):
  765. assert_raises(LinAlgError, eigh, np.ones([3, 3]), np.ones([3, 3]))
  766. # index based subsets are done in the legacy test_eigh()
  767. def test_value_subsets(self):
  768. for ind, dt in enumerate(DTYPES):
  769. a = _random_hermitian_matrix(20, dtype=dt)
  770. w, v = eigh(a, subset_by_value=[-2, 2])
  771. assert_equal(v.shape[1], len(w))
  772. assert all((w > -2) & (w < 2))
  773. b = _random_hermitian_matrix(20, posdef=True, dtype=dt)
  774. w, v = eigh(a, b, subset_by_value=[-2, 2])
  775. assert_equal(v.shape[1], len(w))
  776. assert all((w > -2) & (w < 2))
  777. def test_eigh_integer(self):
  778. a = array([[1, 2], [2, 7]])
  779. b = array([[3, 1], [1, 5]])
  780. w, z = eigh(a)
  781. w, z = eigh(a, b)
  782. @skip_xp_invalid_arg
  783. def test_eigh_of_sparse(self):
  784. # This tests the rejection of inputs that eigh cannot currently handle.
  785. import scipy.sparse
  786. a = scipy.sparse.identity(2).tocsc()
  787. b = np.atleast_2d(a)
  788. assert_raises(ValueError, eigh, a)
  789. assert_raises(ValueError, eigh, b)
  790. @pytest.mark.parametrize('dtype_', DTYPES)
  791. @pytest.mark.parametrize('driver', ("ev", "evd", "evr", "evx"))
  792. def test_various_drivers_standard(self, driver, dtype_):
  793. a = _random_hermitian_matrix(n=20, dtype=dtype_)
  794. w, v = eigh(a, driver=driver)
  795. assert_allclose(a @ v - (v * w), 0.,
  796. atol=1000*np.finfo(dtype_).eps,
  797. rtol=0.)
  798. @pytest.mark.parametrize('driver', ("ev", "evd", "evr", "evx"))
  799. def test_1x1_lwork(self, driver):
  800. w, v = eigh([[1]], driver=driver)
  801. assert_allclose(w, array([1.]), atol=1e-15)
  802. assert_allclose(v, array([[1.]]), atol=1e-15)
  803. # complex case now
  804. w, v = eigh([[1j]], driver=driver)
  805. assert_allclose(w, array([0]), atol=1e-15)
  806. assert_allclose(v, array([[1.]]), atol=1e-15)
  807. @pytest.mark.parametrize('type', (1, 2, 3))
  808. @pytest.mark.parametrize('driver', ("gv", "gvd", "gvx"))
  809. def test_various_drivers_generalized(self, driver, type):
  810. atol = np.spacing(5000.)
  811. a = _random_hermitian_matrix(20)
  812. b = _random_hermitian_matrix(20, posdef=True)
  813. w, v = eigh(a=a, b=b, driver=driver, type=type)
  814. if type == 1:
  815. assert_allclose(a @ v - w*(b @ v), 0., atol=atol, rtol=0.)
  816. elif type == 2:
  817. assert_allclose(a @ b @ v - v * w, 0., atol=atol, rtol=0.)
  818. else:
  819. assert_allclose(b @ a @ v - v * w, 0., atol=atol, rtol=0.)
  820. def test_eigvalsh_new_args(self):
  821. a = _random_hermitian_matrix(5)
  822. w = eigvalsh(a, subset_by_index=[1, 2])
  823. assert_equal(len(w), 2)
  824. w2 = eigvalsh(a, subset_by_index=[1, 2])
  825. assert_equal(len(w2), 2)
  826. assert_allclose(w, w2)
  827. b = np.diag([1, 1.2, 1.3, 1.5, 2])
  828. w3 = eigvalsh(b, subset_by_value=[1, 1.4])
  829. assert_equal(len(w3), 2)
  830. assert_allclose(w3, np.array([1.2, 1.3]))
  831. @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  832. def test_empty(self, dt):
  833. a = np.empty((0, 0), dtype=dt)
  834. w, v = eigh(a)
  835. w_n, v_n = eigh(np.eye(2, dtype=dt))
  836. assert w.shape == (0,)
  837. assert w.dtype == w_n.dtype
  838. assert v.shape == (0, 0)
  839. assert v.dtype == v_n.dtype
  840. w = eigh(a, eigvals_only=True)
  841. assert_allclose(w, np.empty((0,)))
  842. assert w.shape == (0,)
  843. assert w.dtype == w_n.dtype
  844. class TestSVD_GESDD:
  845. lapack_driver = 'gesdd'
  846. def test_degenerate(self):
  847. assert_raises(TypeError, svd, [[1.]], lapack_driver=1.)
  848. assert_raises(ValueError, svd, [[1.]], lapack_driver='foo')
  849. def test_simple(self):
  850. a = [[1, 2, 3], [1, 20, 3], [2, 5, 6]]
  851. for full_matrices in (True, False):
  852. u, s, vh = svd(a, full_matrices=full_matrices,
  853. lapack_driver=self.lapack_driver)
  854. assert_array_almost_equal(u.T @ u, eye(3))
  855. assert_array_almost_equal(vh.T @ vh, eye(3))
  856. sigma = zeros((u.shape[0], vh.shape[0]), s.dtype.char)
  857. for i in range(len(s)):
  858. sigma[i, i] = s[i]
  859. assert_array_almost_equal(u @ sigma @ vh, a)
  860. def test_simple_singular(self):
  861. a = [[1, 2, 3], [1, 2, 3], [2, 5, 6]]
  862. for full_matrices in (True, False):
  863. u, s, vh = svd(a, full_matrices=full_matrices,
  864. lapack_driver=self.lapack_driver)
  865. assert_array_almost_equal(u.T @ u, eye(3))
  866. assert_array_almost_equal(vh.T @ vh, eye(3))
  867. sigma = zeros((u.shape[0], vh.shape[0]), s.dtype.char)
  868. for i in range(len(s)):
  869. sigma[i, i] = s[i]
  870. assert_array_almost_equal(u @ sigma @ vh, a)
  871. def test_simple_underdet(self):
  872. a = [[1, 2, 3], [4, 5, 6]]
  873. for full_matrices in (True, False):
  874. u, s, vh = svd(a, full_matrices=full_matrices,
  875. lapack_driver=self.lapack_driver)
  876. assert_array_almost_equal(u.T @ u, eye(u.shape[0]))
  877. sigma = zeros((u.shape[0], vh.shape[0]), s.dtype.char)
  878. for i in range(len(s)):
  879. sigma[i, i] = s[i]
  880. assert_array_almost_equal(u @ sigma @ vh, a)
  881. def test_simple_overdet(self):
  882. a = [[1, 2], [4, 5], [3, 4]]
  883. for full_matrices in (True, False):
  884. u, s, vh = svd(a, full_matrices=full_matrices,
  885. lapack_driver=self.lapack_driver)
  886. assert_array_almost_equal(u.T @ u, eye(u.shape[1]))
  887. assert_array_almost_equal(vh.T @ vh, eye(2))
  888. sigma = zeros((u.shape[1], vh.shape[0]), s.dtype.char)
  889. for i in range(len(s)):
  890. sigma[i, i] = s[i]
  891. assert_array_almost_equal(u @ sigma @ vh, a)
  892. def test_random(self):
  893. rng = np.random.RandomState(1234)
  894. n = 20
  895. m = 15
  896. for i in range(3):
  897. for a in [rng.random([n, m]), rng.random([m, n])]:
  898. for full_matrices in (True, False):
  899. u, s, vh = svd(a, full_matrices=full_matrices,
  900. lapack_driver=self.lapack_driver)
  901. assert_array_almost_equal(u.T @ u, eye(u.shape[1]))
  902. assert_array_almost_equal(vh @ vh.T, eye(vh.shape[0]))
  903. sigma = zeros((u.shape[1], vh.shape[0]), s.dtype.char)
  904. for i in range(len(s)):
  905. sigma[i, i] = s[i]
  906. assert_array_almost_equal(u @ sigma @ vh, a)
  907. def test_simple_complex(self):
  908. a = [[1, 2, 3], [1, 2j, 3], [2, 5, 6]]
  909. for full_matrices in (True, False):
  910. u, s, vh = svd(a, full_matrices=full_matrices,
  911. lapack_driver=self.lapack_driver)
  912. assert_array_almost_equal(u.conj().T @ u, eye(u.shape[1]))
  913. assert_array_almost_equal(vh.conj().T @ vh, eye(vh.shape[0]))
  914. sigma = zeros((u.shape[0], vh.shape[0]), s.dtype.char)
  915. for i in range(len(s)):
  916. sigma[i, i] = s[i]
  917. assert_array_almost_equal(u @ sigma @ vh, a)
  918. def test_random_complex(self):
  919. rng = np.random.RandomState(1234)
  920. n = 20
  921. m = 15
  922. for i in range(3):
  923. for full_matrices in (True, False):
  924. for a in [rng.random([n, m]), rng.random([m, n])]:
  925. a = a + 1j*rng.random(list(a.shape))
  926. u, s, vh = svd(a, full_matrices=full_matrices,
  927. lapack_driver=self.lapack_driver)
  928. assert_array_almost_equal(u.conj().T @ u,
  929. eye(u.shape[1]))
  930. # This fails when [m,n]
  931. # assert_array_almost_equal(vh.conj().T @ vh,
  932. # eye(len(vh),dtype=vh.dtype.char))
  933. sigma = zeros((u.shape[1], vh.shape[0]), s.dtype.char)
  934. for i in range(len(s)):
  935. sigma[i, i] = s[i]
  936. assert_array_almost_equal(u @ sigma @ vh, a)
  937. def test_crash_1580(self):
  938. rng = np.random.RandomState(1234)
  939. sizes = [(13, 23), (30, 50), (60, 100)]
  940. for sz in sizes:
  941. for dt in [np.float32, np.float64, np.complex64, np.complex128]:
  942. a = rng.rand(*sz).astype(dt)
  943. # should not crash
  944. svd(a, lapack_driver=self.lapack_driver)
  945. def test_check_finite(self):
  946. a = [[1, 2, 3], [1, 20, 3], [2, 5, 6]]
  947. u, s, vh = svd(a, check_finite=False, lapack_driver=self.lapack_driver)
  948. assert_array_almost_equal(u.T @ u, eye(3))
  949. assert_array_almost_equal(vh.T @ vh, eye(3))
  950. sigma = zeros((u.shape[0], vh.shape[0]), s.dtype.char)
  951. for i in range(len(s)):
  952. sigma[i, i] = s[i]
  953. assert_array_almost_equal(u @ sigma @ vh, a)
  954. def test_gh_5039(self):
  955. # This is a smoke test for https://github.com/scipy/scipy/issues/5039
  956. #
  957. # The following is reported to raise "ValueError: On entry to DGESDD
  958. # parameter number 12 had an illegal value".
  959. # `interp1d([1,2,3,4], [1,2,3,4], kind='cubic')`
  960. # This is reported to only show up on LAPACK 3.0.3.
  961. #
  962. # The matrix below is taken from the call to
  963. # `B = _fitpack._bsplmat(order, xk)` in interpolate._find_smoothest
  964. b = np.array(
  965. [[0.16666667, 0.66666667, 0.16666667, 0., 0., 0.],
  966. [0., 0.16666667, 0.66666667, 0.16666667, 0., 0.],
  967. [0., 0., 0.16666667, 0.66666667, 0.16666667, 0.],
  968. [0., 0., 0., 0.16666667, 0.66666667, 0.16666667]])
  969. svd(b, lapack_driver=self.lapack_driver)
  970. @pytest.mark.skipif(not HAS_ILP64, reason="64-bit LAPACK required")
  971. @pytest.mark.slow
  972. def test_large_matrix(self):
  973. check_free_memory(free_mb=17000)
  974. A = np.zeros([1, 2**31], dtype=np.float32)
  975. A[0, -1] = 1
  976. u, s, vh = svd(A, full_matrices=False)
  977. assert_allclose(s[0], 1.0)
  978. assert_allclose(u[0, 0] * vh[0, -1], 1.0)
  979. @pytest.mark.parametrize("m", [0, 1, 2])
  980. @pytest.mark.parametrize("n", [0, 1, 2])
  981. @pytest.mark.parametrize('dtype', DTYPES)
  982. def test_shape_dtype(self, m, n, dtype):
  983. a = np.zeros((m, n), dtype=dtype)
  984. k = min(m, n)
  985. dchar = a.dtype.char
  986. real_dchar = dchar.lower() if dchar in 'FD' else dchar
  987. u, s, v = svd(a)
  988. assert_equal(u.shape, (m, m))
  989. assert_equal(u.dtype, dtype)
  990. assert_equal(s.shape, (k,))
  991. assert_equal(s.dtype, np.dtype(real_dchar))
  992. assert_equal(v.shape, (n, n))
  993. assert_equal(v.dtype, dtype)
  994. u, s, v = svd(a, full_matrices=False)
  995. assert_equal(u.shape, (m, k))
  996. assert_equal(u.dtype, dtype)
  997. assert_equal(s.shape, (k,))
  998. assert_equal(s.dtype, np.dtype(real_dchar))
  999. assert_equal(v.shape, (k, n))
  1000. assert_equal(v.dtype, dtype)
  1001. s = svd(a, compute_uv=False)
  1002. assert_equal(s.shape, (k,))
  1003. assert_equal(s.dtype, np.dtype(real_dchar))
  1004. @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  1005. @pytest.mark.parametrize(("m", "n"), [(0, 0), (0, 2), (2, 0)])
  1006. def test_empty(self, dt, m, n):
  1007. a0 = np.eye(3, dtype=dt)
  1008. u0, s0, v0 = svd(a0)
  1009. a = np.empty((m, n), dtype=dt)
  1010. u, s, v = svd(a)
  1011. assert_allclose(u, np.identity(m))
  1012. assert_allclose(s, np.empty((0,)))
  1013. assert_allclose(v, np.identity(n))
  1014. assert u.dtype == u0.dtype
  1015. assert v.dtype == v0.dtype
  1016. assert s.dtype == s0.dtype
  1017. u, s, v = svd(a, full_matrices=False)
  1018. assert_allclose(u, np.empty((m, 0)))
  1019. assert_allclose(s, np.empty((0,)))
  1020. assert_allclose(v, np.empty((0, n)))
  1021. assert u.dtype == u0.dtype
  1022. assert v.dtype == v0.dtype
  1023. assert s.dtype == s0.dtype
  1024. s = svd(a, compute_uv=False)
  1025. assert_allclose(s, np.empty((0,)))
  1026. assert s.dtype == s0.dtype
  1027. class TestSVD_GESVD(TestSVD_GESDD):
  1028. lapack_driver = 'gesvd'
  1029. # Allocating an array of such a size leads to _ArrayMemoryError(s)
  1030. # since the maximum memory that can be in 32-bit (WASM) is 4GB
  1031. @pytest.mark.skipif(IS_WASM, reason="out of memory in WASM")
  1032. @pytest.mark.xfail_on_32bit("out of memory in 32-bit CI workflow")
  1033. @pytest.mark.parallel_threads_limit(2) # 1.9 GiB per thread RAM usage
  1034. @pytest.mark.fail_slow(10)
  1035. def test_svd_gesdd_nofegfault():
  1036. # svd(a) with {U,VT}.size > INT_MAX does not segfault
  1037. # cf https://github.com/scipy/scipy/issues/14001
  1038. df=np.ones((4799, 53130), dtype=np.float64)
  1039. with assert_raises(ValueError):
  1040. svd(df)
  1041. def test_gesdd_nan_error_message():
  1042. A = np.eye(2)
  1043. A[0, 0] = np.nan
  1044. with pytest.raises(ValueError, match="NaN"):
  1045. svd(A, check_finite=False)
  1046. class TestSVDVals:
  1047. @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  1048. def test_empty(self, dt):
  1049. for a in [[]], np.empty((2, 0)), np.ones((0, 3)):
  1050. a = np.array(a, dtype=dt)
  1051. s = svdvals(a)
  1052. assert_equal(s, np.empty(0))
  1053. s0 = svdvals(np.eye(2, dtype=dt))
  1054. assert s.dtype == s0.dtype
  1055. def test_simple(self):
  1056. a = [[1, 2, 3], [1, 2, 3], [2, 5, 6]]
  1057. s = svdvals(a)
  1058. assert_(len(s) == 3)
  1059. assert_(s[0] >= s[1] >= s[2])
  1060. def test_simple_underdet(self):
  1061. a = [[1, 2, 3], [4, 5, 6]]
  1062. s = svdvals(a)
  1063. assert_(len(s) == 2)
  1064. assert_(s[0] >= s[1])
  1065. def test_simple_overdet(self):
  1066. a = [[1, 2], [4, 5], [3, 4]]
  1067. s = svdvals(a)
  1068. assert_(len(s) == 2)
  1069. assert_(s[0] >= s[1])
  1070. def test_simple_complex(self):
  1071. a = [[1, 2, 3], [1, 20, 3j], [2, 5, 6]]
  1072. s = svdvals(a)
  1073. assert_(len(s) == 3)
  1074. assert_(s[0] >= s[1] >= s[2])
  1075. def test_simple_underdet_complex(self):
  1076. a = [[1, 2, 3], [4, 5j, 6]]
  1077. s = svdvals(a)
  1078. assert_(len(s) == 2)
  1079. assert_(s[0] >= s[1])
  1080. def test_simple_overdet_complex(self):
  1081. a = [[1, 2], [4, 5], [3j, 4]]
  1082. s = svdvals(a)
  1083. assert_(len(s) == 2)
  1084. assert_(s[0] >= s[1])
  1085. def test_check_finite(self):
  1086. a = [[1, 2, 3], [1, 2, 3], [2, 5, 6]]
  1087. s = svdvals(a, check_finite=False)
  1088. assert_(len(s) == 3)
  1089. assert_(s[0] >= s[1] >= s[2])
  1090. @pytest.mark.slow
  1091. def test_crash_2609(self):
  1092. rng = np.random.default_rng(1234)
  1093. a = rng.random((1500, 2800))
  1094. # Shouldn't crash:
  1095. svdvals(a)
  1096. class TestDiagSVD:
  1097. def test_simple(self):
  1098. assert_array_almost_equal(diagsvd([1, 0, 0], 3, 3),
  1099. [[1, 0, 0], [0, 0, 0], [0, 0, 0]])
  1100. class TestQR:
  1101. def test_simple(self):
  1102. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1103. q, r = qr(a)
  1104. assert_array_almost_equal(q.T @ q, eye(3))
  1105. assert_array_almost_equal(q @ r, a)
  1106. def test_simple_left(self):
  1107. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1108. q, r = qr(a)
  1109. c = [1, 2, 3]
  1110. qc, r2 = qr_multiply(a, c, "left")
  1111. assert_array_almost_equal(q @ c, qc)
  1112. assert_array_almost_equal(r, r2)
  1113. qc, r2 = qr_multiply(a, eye(3), "left")
  1114. assert_array_almost_equal(q, qc)
  1115. def test_simple_right(self):
  1116. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1117. q, r = qr(a)
  1118. c = [1, 2, 3]
  1119. qc, r2 = qr_multiply(a, c)
  1120. assert_array_almost_equal(c @ q, qc)
  1121. assert_array_almost_equal(r, r2)
  1122. qc, r = qr_multiply(a, eye(3))
  1123. assert_array_almost_equal(q, qc)
  1124. def test_simple_pivoting(self):
  1125. a = np.asarray([[8, 2, 3], [2, 9, 3], [5, 3, 6]])
  1126. q, r, p = qr(a, pivoting=True)
  1127. d = abs(diag(r))
  1128. assert_(np.all(d[1:] <= d[:-1]))
  1129. assert_array_almost_equal(q.T @ q, eye(3))
  1130. assert_array_almost_equal(q @ r, a[:, p])
  1131. q2, r2 = qr(a[:, p])
  1132. assert_array_almost_equal(q, q2)
  1133. assert_array_almost_equal(r, r2)
  1134. def test_simple_left_pivoting(self):
  1135. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1136. q, r, jpvt = qr(a, pivoting=True)
  1137. c = [1, 2, 3]
  1138. qc, r, jpvt = qr_multiply(a, c, "left", True)
  1139. assert_array_almost_equal(q @ c, qc)
  1140. def test_simple_right_pivoting(self):
  1141. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1142. q, r, jpvt = qr(a, pivoting=True)
  1143. c = [1, 2, 3]
  1144. qc, r, jpvt = qr_multiply(a, c, pivoting=True)
  1145. assert_array_almost_equal(c @ q, qc)
  1146. def test_simple_trap(self):
  1147. a = [[8, 2, 3], [2, 9, 3]]
  1148. q, r = qr(a)
  1149. assert_array_almost_equal(q.T @ q, eye(2))
  1150. assert_array_almost_equal(q @ r, a)
  1151. def test_simple_trap_pivoting(self):
  1152. a = np.asarray([[8, 2, 3], [2, 9, 3]])
  1153. q, r, p = qr(a, pivoting=True)
  1154. d = abs(diag(r))
  1155. assert_(np.all(d[1:] <= d[:-1]))
  1156. assert_array_almost_equal(q.T @ q, eye(2))
  1157. assert_array_almost_equal(q @ r, a[:, p])
  1158. q2, r2 = qr(a[:, p])
  1159. assert_array_almost_equal(q, q2)
  1160. assert_array_almost_equal(r, r2)
  1161. def test_simple_tall(self):
  1162. # full version
  1163. a = [[8, 2], [2, 9], [5, 3]]
  1164. q, r = qr(a)
  1165. assert_array_almost_equal(q.T @ q, eye(3))
  1166. assert_array_almost_equal(q @ r, a)
  1167. def test_simple_tall_pivoting(self):
  1168. # full version pivoting
  1169. a = np.asarray([[8, 2], [2, 9], [5, 3]])
  1170. q, r, p = qr(a, pivoting=True)
  1171. d = abs(diag(r))
  1172. assert_(np.all(d[1:] <= d[:-1]))
  1173. assert_array_almost_equal(q.T @ q, eye(3))
  1174. assert_array_almost_equal(q @ r, a[:, p])
  1175. q2, r2 = qr(a[:, p])
  1176. assert_array_almost_equal(q, q2)
  1177. assert_array_almost_equal(r, r2)
  1178. def test_simple_tall_e(self):
  1179. # economy version
  1180. a = [[8, 2], [2, 9], [5, 3]]
  1181. q, r = qr(a, mode='economic')
  1182. assert_array_almost_equal(q.T @ q, eye(2))
  1183. assert_array_almost_equal(q @ r, a)
  1184. assert_equal(q.shape, (3, 2))
  1185. assert_equal(r.shape, (2, 2))
  1186. def test_simple_tall_e_pivoting(self):
  1187. # economy version pivoting
  1188. a = np.asarray([[8, 2], [2, 9], [5, 3]])
  1189. q, r, p = qr(a, pivoting=True, mode='economic')
  1190. d = abs(diag(r))
  1191. assert_(np.all(d[1:] <= d[:-1]))
  1192. assert_array_almost_equal(q.T @ q, eye(2))
  1193. assert_array_almost_equal(q @ r, a[:, p])
  1194. q2, r2 = qr(a[:, p], mode='economic')
  1195. assert_array_almost_equal(q, q2)
  1196. assert_array_almost_equal(r, r2)
  1197. def test_simple_tall_left(self):
  1198. a = [[8, 2], [2, 9], [5, 3]]
  1199. q, r = qr(a, mode="economic")
  1200. c = [1, 2]
  1201. qc, r2 = qr_multiply(a, c, "left")
  1202. assert_array_almost_equal(q @ c, qc)
  1203. assert_array_almost_equal(r, r2)
  1204. c = array([1, 2, 0])
  1205. qc, r2 = qr_multiply(a, c, "left", overwrite_c=True)
  1206. assert_array_almost_equal(q @ c[:2], qc)
  1207. qc, r = qr_multiply(a, eye(2), "left")
  1208. assert_array_almost_equal(qc, q)
  1209. def test_simple_tall_left_pivoting(self):
  1210. a = [[8, 2], [2, 9], [5, 3]]
  1211. q, r, jpvt = qr(a, mode="economic", pivoting=True)
  1212. c = [1, 2]
  1213. qc, r, kpvt = qr_multiply(a, c, "left", True)
  1214. assert_array_equal(jpvt, kpvt)
  1215. assert_array_almost_equal(q @ c, qc)
  1216. qc, r, jpvt = qr_multiply(a, eye(2), "left", True)
  1217. assert_array_almost_equal(qc, q)
  1218. def test_simple_tall_right(self):
  1219. a = [[8, 2], [2, 9], [5, 3]]
  1220. q, r = qr(a, mode="economic")
  1221. c = [1, 2, 3]
  1222. cq, r2 = qr_multiply(a, c)
  1223. assert_array_almost_equal(c @ q, cq)
  1224. assert_array_almost_equal(r, r2)
  1225. cq, r = qr_multiply(a, eye(3))
  1226. assert_array_almost_equal(cq, q)
  1227. def test_simple_tall_right_pivoting(self):
  1228. a = [[8, 2], [2, 9], [5, 3]]
  1229. q, r, jpvt = qr(a, pivoting=True, mode="economic")
  1230. c = [1, 2, 3]
  1231. cq, r, jpvt = qr_multiply(a, c, pivoting=True)
  1232. assert_array_almost_equal(c @ q, cq)
  1233. cq, r, jpvt = qr_multiply(a, eye(3), pivoting=True)
  1234. assert_array_almost_equal(cq, q)
  1235. def test_simple_fat(self):
  1236. # full version
  1237. a = [[8, 2, 5], [2, 9, 3]]
  1238. q, r = qr(a)
  1239. assert_array_almost_equal(q.T @ q, eye(2))
  1240. assert_array_almost_equal(q @ r, a)
  1241. assert_equal(q.shape, (2, 2))
  1242. assert_equal(r.shape, (2, 3))
  1243. def test_simple_fat_pivoting(self):
  1244. # full version pivoting
  1245. a = np.asarray([[8, 2, 5], [2, 9, 3]])
  1246. q, r, p = qr(a, pivoting=True)
  1247. d = abs(diag(r))
  1248. assert_(np.all(d[1:] <= d[:-1]))
  1249. assert_array_almost_equal(q.T @ q, eye(2))
  1250. assert_array_almost_equal(q @ r, a[:, p])
  1251. assert_equal(q.shape, (2, 2))
  1252. assert_equal(r.shape, (2, 3))
  1253. q2, r2 = qr(a[:, p])
  1254. assert_array_almost_equal(q, q2)
  1255. assert_array_almost_equal(r, r2)
  1256. def test_simple_fat_e(self):
  1257. # economy version
  1258. a = [[8, 2, 3], [2, 9, 5]]
  1259. q, r = qr(a, mode='economic')
  1260. assert_array_almost_equal(q.T @ q, eye(2))
  1261. assert_array_almost_equal(q @ r, a)
  1262. assert_equal(q.shape, (2, 2))
  1263. assert_equal(r.shape, (2, 3))
  1264. def test_simple_fat_e_pivoting(self):
  1265. # economy version pivoting
  1266. a = np.asarray([[8, 2, 3], [2, 9, 5]])
  1267. q, r, p = qr(a, pivoting=True, mode='economic')
  1268. d = abs(diag(r))
  1269. assert_(np.all(d[1:] <= d[:-1]))
  1270. assert_array_almost_equal(q.T @ q, eye(2))
  1271. assert_array_almost_equal(q @ r, a[:, p])
  1272. assert_equal(q.shape, (2, 2))
  1273. assert_equal(r.shape, (2, 3))
  1274. q2, r2 = qr(a[:, p], mode='economic')
  1275. assert_array_almost_equal(q, q2)
  1276. assert_array_almost_equal(r, r2)
  1277. def test_simple_fat_left(self):
  1278. a = [[8, 2, 3], [2, 9, 5]]
  1279. q, r = qr(a, mode="economic")
  1280. c = [1, 2]
  1281. qc, r2 = qr_multiply(a, c, "left")
  1282. assert_array_almost_equal(q @ c, qc)
  1283. assert_array_almost_equal(r, r2)
  1284. qc, r = qr_multiply(a, eye(2), "left")
  1285. assert_array_almost_equal(qc, q)
  1286. def test_simple_fat_left_pivoting(self):
  1287. a = [[8, 2, 3], [2, 9, 5]]
  1288. q, r, jpvt = qr(a, mode="economic", pivoting=True)
  1289. c = [1, 2]
  1290. qc, r, jpvt = qr_multiply(a, c, "left", True)
  1291. assert_array_almost_equal(q @ c, qc)
  1292. qc, r, jpvt = qr_multiply(a, eye(2), "left", True)
  1293. assert_array_almost_equal(qc, q)
  1294. def test_simple_fat_right(self):
  1295. a = [[8, 2, 3], [2, 9, 5]]
  1296. q, r = qr(a, mode="economic")
  1297. c = [1, 2]
  1298. cq, r2 = qr_multiply(a, c)
  1299. assert_array_almost_equal(c @ q, cq)
  1300. assert_array_almost_equal(r, r2)
  1301. cq, r = qr_multiply(a, eye(2))
  1302. assert_array_almost_equal(cq, q)
  1303. def test_simple_fat_right_pivoting(self):
  1304. a = [[8, 2, 3], [2, 9, 5]]
  1305. q, r, jpvt = qr(a, pivoting=True, mode="economic")
  1306. c = [1, 2]
  1307. cq, r, jpvt = qr_multiply(a, c, pivoting=True)
  1308. assert_array_almost_equal(c @ q, cq)
  1309. cq, r, jpvt = qr_multiply(a, eye(2), pivoting=True)
  1310. assert_array_almost_equal(cq, q)
  1311. def test_simple_complex(self):
  1312. a = [[3, 3+4j, 5], [5, 2, 2+7j], [3, 2, 7]]
  1313. q, r = qr(a)
  1314. assert_array_almost_equal(q.conj().T @ q, eye(3))
  1315. assert_array_almost_equal(q @ r, a)
  1316. def test_simple_complex_left(self):
  1317. a = [[3, 3+4j, 5], [5, 2, 2+7j], [3, 2, 7]]
  1318. q, r = qr(a)
  1319. c = [1, 2, 3+4j]
  1320. qc, r = qr_multiply(a, c, "left")
  1321. assert_array_almost_equal(q @ c, qc)
  1322. qc, r = qr_multiply(a, eye(3), "left")
  1323. assert_array_almost_equal(q, qc)
  1324. def test_simple_complex_right(self):
  1325. a = [[3, 3+4j, 5], [5, 2, 2+7j], [3, 2, 7]]
  1326. q, r = qr(a)
  1327. c = [1, 2, 3+4j]
  1328. qc, r = qr_multiply(a, c)
  1329. assert_array_almost_equal(c @ q, qc)
  1330. qc, r = qr_multiply(a, eye(3))
  1331. assert_array_almost_equal(q, qc)
  1332. def test_simple_tall_complex_left(self):
  1333. a = [[8, 2+3j], [2, 9], [5+7j, 3]]
  1334. q, r = qr(a, mode="economic")
  1335. c = [1, 2+2j]
  1336. qc, r2 = qr_multiply(a, c, "left")
  1337. assert_array_almost_equal(q @ c, qc)
  1338. assert_array_almost_equal(r, r2)
  1339. c = array([1, 2, 0])
  1340. qc, r2 = qr_multiply(a, c, "left", overwrite_c=True)
  1341. assert_array_almost_equal(q @ c[:2], qc)
  1342. qc, r = qr_multiply(a, eye(2), "left")
  1343. assert_array_almost_equal(qc, q)
  1344. def test_simple_complex_left_conjugate(self):
  1345. a = [[3, 3+4j, 5], [5, 2, 2+7j], [3, 2, 7]]
  1346. q, r = qr(a)
  1347. c = [1, 2, 3+4j]
  1348. qc, r = qr_multiply(a, c, "left", conjugate=True)
  1349. assert_array_almost_equal(q.conj() @ c, qc)
  1350. def test_simple_complex_tall_left_conjugate(self):
  1351. a = [[3, 3+4j], [5, 2+2j], [3, 2]]
  1352. q, r = qr(a, mode='economic')
  1353. c = [1, 3+4j]
  1354. qc, r = qr_multiply(a, c, "left", conjugate=True)
  1355. assert_array_almost_equal(q.conj() @ c, qc)
  1356. def test_simple_complex_right_conjugate(self):
  1357. a = [[3, 3+4j, 5], [5, 2, 2+7j], [3, 2, 7]]
  1358. q, r = qr(a)
  1359. c = np.array([1, 2, 3+4j])
  1360. qc, r = qr_multiply(a, c, conjugate=True)
  1361. assert_array_almost_equal(c @ q.conj(), qc)
  1362. def test_simple_complex_pivoting(self):
  1363. a = array([[3, 3+4j, 5], [5, 2, 2+7j], [3, 2, 7]])
  1364. q, r, p = qr(a, pivoting=True)
  1365. d = abs(diag(r))
  1366. assert_(np.all(d[1:] <= d[:-1]))
  1367. assert_array_almost_equal(q.conj().T @ q, eye(3))
  1368. assert_array_almost_equal(q @ r, a[:, p])
  1369. q2, r2 = qr(a[:, p])
  1370. assert_array_almost_equal(q, q2)
  1371. assert_array_almost_equal(r, r2)
  1372. def test_simple_complex_left_pivoting(self):
  1373. a = array([[3, 3+4j, 5], [5, 2, 2+7j], [3, 2, 7]])
  1374. q, r, jpvt = qr(a, pivoting=True)
  1375. c = [1, 2, 3+4j]
  1376. qc, r, jpvt = qr_multiply(a, c, "left", True)
  1377. assert_array_almost_equal(q @ c, qc)
  1378. def test_simple_complex_right_pivoting(self):
  1379. a = array([[3, 3+4j, 5], [5, 2, 2+7j], [3, 2, 7]])
  1380. q, r, jpvt = qr(a, pivoting=True)
  1381. c = [1, 2, 3+4j]
  1382. qc, r, jpvt = qr_multiply(a, c, pivoting=True)
  1383. assert_array_almost_equal(c @ q, qc)
  1384. def test_random(self):
  1385. rng = np.random.RandomState(1234)
  1386. n = 20
  1387. for k in range(2):
  1388. a = rng.random([n, n])
  1389. q, r = qr(a)
  1390. assert_array_almost_equal(q.T @ q, eye(n))
  1391. assert_array_almost_equal(q @ r, a)
  1392. def test_random_left(self):
  1393. rng = np.random.RandomState(1234)
  1394. n = 20
  1395. for k in range(2):
  1396. a = rng.random([n, n])
  1397. q, r = qr(a)
  1398. c = rng.random([n])
  1399. qc, r = qr_multiply(a, c, "left")
  1400. assert_array_almost_equal(q @ c, qc)
  1401. qc, r = qr_multiply(a, eye(n), "left")
  1402. assert_array_almost_equal(q, qc)
  1403. def test_random_right(self):
  1404. rng = np.random.RandomState(1234)
  1405. n = 20
  1406. for k in range(2):
  1407. a = rng.random([n, n])
  1408. q, r = qr(a)
  1409. c = rng.random([n])
  1410. cq, r = qr_multiply(a, c)
  1411. assert_array_almost_equal(c @ q, cq)
  1412. cq, r = qr_multiply(a, eye(n))
  1413. assert_array_almost_equal(q, cq)
  1414. def test_random_pivoting(self):
  1415. rng = np.random.RandomState(1234)
  1416. n = 20
  1417. for k in range(2):
  1418. a = rng.random([n, n])
  1419. q, r, p = qr(a, pivoting=True)
  1420. d = abs(diag(r))
  1421. assert_(np.all(d[1:] <= d[:-1]))
  1422. assert_array_almost_equal(q.T @ q, eye(n))
  1423. assert_array_almost_equal(q @ r, a[:, p])
  1424. q2, r2 = qr(a[:, p])
  1425. assert_array_almost_equal(q, q2)
  1426. assert_array_almost_equal(r, r2)
  1427. def test_random_tall(self):
  1428. rng = np.random.RandomState(1234)
  1429. # full version
  1430. m = 200
  1431. n = 100
  1432. for k in range(2):
  1433. a = rng.random([m, n])
  1434. q, r = qr(a)
  1435. assert_array_almost_equal(q.T @ q, eye(m))
  1436. assert_array_almost_equal(q @ r, a)
  1437. def test_random_tall_left(self):
  1438. rng = np.random.RandomState(1234)
  1439. # full version
  1440. m = 200
  1441. n = 100
  1442. for k in range(2):
  1443. a = rng.random([m, n])
  1444. q, r = qr(a, mode="economic")
  1445. c = rng.random([n])
  1446. qc, r = qr_multiply(a, c, "left")
  1447. assert_array_almost_equal(q @ c, qc)
  1448. qc, r = qr_multiply(a, eye(n), "left")
  1449. assert_array_almost_equal(qc, q)
  1450. def test_random_tall_right(self):
  1451. rng = np.random.RandomState(1234)
  1452. # full version
  1453. m = 200
  1454. n = 100
  1455. for k in range(2):
  1456. a = rng.random([m, n])
  1457. q, r = qr(a, mode="economic")
  1458. c = rng.random([m])
  1459. cq, r = qr_multiply(a, c)
  1460. assert_array_almost_equal(c @ q, cq)
  1461. cq, r = qr_multiply(a, eye(m))
  1462. assert_array_almost_equal(cq, q)
  1463. def test_random_tall_pivoting(self):
  1464. rng = np.random.RandomState(1234)
  1465. # full version pivoting
  1466. m = 200
  1467. n = 100
  1468. for k in range(2):
  1469. a = rng.random([m, n])
  1470. q, r, p = qr(a, pivoting=True)
  1471. d = abs(diag(r))
  1472. assert_(np.all(d[1:] <= d[:-1]))
  1473. assert_array_almost_equal(q.T @ q, eye(m))
  1474. assert_array_almost_equal(q @ r, a[:, p])
  1475. q2, r2 = qr(a[:, p])
  1476. assert_array_almost_equal(q, q2)
  1477. assert_array_almost_equal(r, r2)
  1478. def test_random_tall_e(self):
  1479. rng = np.random.RandomState(1234)
  1480. # economy version
  1481. m = 200
  1482. n = 100
  1483. for k in range(2):
  1484. a = rng.random([m, n])
  1485. q, r = qr(a, mode='economic')
  1486. assert_array_almost_equal(q.T @ q, eye(n))
  1487. assert_array_almost_equal(q @ r, a)
  1488. assert_equal(q.shape, (m, n))
  1489. assert_equal(r.shape, (n, n))
  1490. def test_random_tall_e_pivoting(self):
  1491. rng = np.random.RandomState(1234)
  1492. # economy version pivoting
  1493. m = 200
  1494. n = 100
  1495. for k in range(2):
  1496. a = rng.random([m, n])
  1497. q, r, p = qr(a, pivoting=True, mode='economic')
  1498. d = abs(diag(r))
  1499. assert_(np.all(d[1:] <= d[:-1]))
  1500. assert_array_almost_equal(q.T @ q, eye(n))
  1501. assert_array_almost_equal(q @ r, a[:, p])
  1502. assert_equal(q.shape, (m, n))
  1503. assert_equal(r.shape, (n, n))
  1504. q2, r2 = qr(a[:, p], mode='economic')
  1505. assert_array_almost_equal(q, q2)
  1506. assert_array_almost_equal(r, r2)
  1507. def test_random_trap(self):
  1508. rng = np.random.RandomState(1234)
  1509. m = 100
  1510. n = 200
  1511. for k in range(2):
  1512. a = rng.random([m, n])
  1513. q, r = qr(a)
  1514. assert_array_almost_equal(q.T @ q, eye(m))
  1515. assert_array_almost_equal(q @ r, a)
  1516. def test_random_trap_pivoting(self):
  1517. rng = np.random.RandomState(1234)
  1518. m = 100
  1519. n = 200
  1520. for k in range(2):
  1521. a = rng.random([m, n])
  1522. q, r, p = qr(a, pivoting=True)
  1523. d = abs(diag(r))
  1524. assert_(np.all(d[1:] <= d[:-1]))
  1525. assert_array_almost_equal(q.T @ q, eye(m))
  1526. assert_array_almost_equal(q @ r, a[:, p])
  1527. q2, r2 = qr(a[:, p])
  1528. assert_array_almost_equal(q, q2)
  1529. assert_array_almost_equal(r, r2)
  1530. def test_random_complex(self):
  1531. rng = np.random.RandomState(1234)
  1532. n = 20
  1533. for k in range(2):
  1534. a = rng.random([n, n]) + 1j*rng.random([n, n])
  1535. q, r = qr(a)
  1536. assert_array_almost_equal(q.conj().T @ q, eye(n))
  1537. assert_array_almost_equal(q @ r, a)
  1538. def test_random_complex_left(self):
  1539. rng = np.random.RandomState(1234)
  1540. n = 20
  1541. for k in range(2):
  1542. a = rng.random([n, n]) + 1j*rng.random([n, n])
  1543. q, r = qr(a)
  1544. c = rng.random([n]) + 1j*rng.random([n])
  1545. qc, r = qr_multiply(a, c, "left")
  1546. assert_array_almost_equal(q @ c, qc)
  1547. qc, r = qr_multiply(a, eye(n), "left")
  1548. assert_array_almost_equal(q, qc)
  1549. def test_random_complex_right(self):
  1550. rng = np.random.RandomState(1234)
  1551. n = 20
  1552. for k in range(2):
  1553. a = rng.random([n, n]) + 1j*rng.random([n, n])
  1554. q, r = qr(a)
  1555. c = rng.random([n]) + 1j*rng.random([n])
  1556. cq, r = qr_multiply(a, c)
  1557. assert_array_almost_equal(c @ q, cq)
  1558. cq, r = qr_multiply(a, eye(n))
  1559. assert_array_almost_equal(q, cq)
  1560. def test_random_complex_pivoting(self):
  1561. rng = np.random.RandomState(1234)
  1562. n = 20
  1563. for k in range(2):
  1564. a = rng.random([n, n]) + 1j*rng.random([n, n])
  1565. q, r, p = qr(a, pivoting=True)
  1566. d = abs(diag(r))
  1567. assert_(np.all(d[1:] <= d[:-1]))
  1568. assert_array_almost_equal(q.conj().T @ q, eye(n))
  1569. assert_array_almost_equal(q @ r, a[:, p])
  1570. q2, r2 = qr(a[:, p])
  1571. assert_array_almost_equal(q, q2)
  1572. assert_array_almost_equal(r, r2)
  1573. def test_check_finite(self):
  1574. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1575. q, r = qr(a, check_finite=False)
  1576. assert_array_almost_equal(q.T @ q, eye(3))
  1577. assert_array_almost_equal(q @ r, a)
  1578. def test_lwork(self):
  1579. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1580. # Get comparison values
  1581. q, r = qr(a, lwork=None)
  1582. # Test against minimum valid lwork
  1583. q2, r2 = qr(a, lwork=3)
  1584. assert_array_almost_equal(q2, q)
  1585. assert_array_almost_equal(r2, r)
  1586. # Test against larger lwork
  1587. q3, r3 = qr(a, lwork=10)
  1588. assert_array_almost_equal(q3, q)
  1589. assert_array_almost_equal(r3, r)
  1590. # Test against explicit lwork=-1
  1591. q4, r4 = qr(a, lwork=-1)
  1592. assert_array_almost_equal(q4, q)
  1593. assert_array_almost_equal(r4, r)
  1594. # Test against invalid lwork
  1595. assert_raises(Exception, qr, (a,), {'lwork': 0})
  1596. assert_raises(Exception, qr, (a,), {'lwork': 2})
  1597. @pytest.mark.parametrize("m", [0, 1, 2])
  1598. @pytest.mark.parametrize("n", [0, 1, 2])
  1599. @pytest.mark.parametrize("pivoting", [False, True])
  1600. @pytest.mark.parametrize('dtype', DTYPES)
  1601. def test_shape_dtype(self, m, n, pivoting, dtype):
  1602. k = min(m, n)
  1603. a = np.zeros((m, n), dtype=dtype)
  1604. q, r, *other = qr(a, pivoting=pivoting)
  1605. assert_equal(q.shape, (m, m))
  1606. assert_equal(q.dtype, dtype)
  1607. assert_equal(r.shape, (m, n))
  1608. assert_equal(r.dtype, dtype)
  1609. assert len(other) == (1 if pivoting else 0)
  1610. if pivoting:
  1611. p, = other
  1612. assert_equal(p.shape, (n,))
  1613. assert_equal(p.dtype, np.int32)
  1614. r, *other = qr(a, mode='r', pivoting=pivoting)
  1615. assert_equal(r.shape, (m, n))
  1616. assert_equal(r.dtype, dtype)
  1617. assert len(other) == (1 if pivoting else 0)
  1618. if pivoting:
  1619. p, = other
  1620. assert_equal(p.shape, (n,))
  1621. assert_equal(p.dtype, np.int32)
  1622. q, r, *other = qr(a, mode='economic', pivoting=pivoting)
  1623. assert_equal(q.shape, (m, k))
  1624. assert_equal(q.dtype, dtype)
  1625. assert_equal(r.shape, (k, n))
  1626. assert_equal(r.dtype, dtype)
  1627. assert len(other) == (1 if pivoting else 0)
  1628. if pivoting:
  1629. p, = other
  1630. assert_equal(p.shape, (n,))
  1631. assert_equal(p.dtype, np.int32)
  1632. (raw, tau), r, *other = qr(a, mode='raw', pivoting=pivoting)
  1633. assert_equal(raw.shape, (m, n))
  1634. assert_equal(raw.dtype, dtype)
  1635. assert_equal(tau.shape, (k,))
  1636. assert_equal(tau.dtype, dtype)
  1637. assert_equal(r.shape, (k, n))
  1638. assert_equal(r.dtype, dtype)
  1639. assert len(other) == (1 if pivoting else 0)
  1640. if pivoting:
  1641. p, = other
  1642. assert_equal(p.shape, (n,))
  1643. assert_equal(p.dtype, np.int32)
  1644. @pytest.mark.parametrize(("m", "n"), [(0, 0), (0, 2), (2, 0)])
  1645. def test_empty(self, m, n):
  1646. k = min(m, n)
  1647. a = np.empty((m, n))
  1648. q, r = qr(a)
  1649. assert_allclose(q, np.identity(m))
  1650. assert_allclose(r, np.empty((m, n)))
  1651. q, r, p = qr(a, pivoting=True)
  1652. assert_allclose(q, np.identity(m))
  1653. assert_allclose(r, np.empty((m, n)))
  1654. assert_allclose(p, np.arange(n))
  1655. r, = qr(a, mode='r')
  1656. assert_allclose(r, np.empty((m, n)))
  1657. q, r = qr(a, mode='economic')
  1658. assert_allclose(q, np.empty((m, k)))
  1659. assert_allclose(r, np.empty((k, n)))
  1660. (raw, tau), r = qr(a, mode='raw')
  1661. assert_allclose(raw, np.empty((m, n)))
  1662. assert_allclose(tau, np.empty((k,)))
  1663. assert_allclose(r, np.empty((k, n)))
  1664. def test_multiply_empty(self):
  1665. a = np.empty((0, 0))
  1666. c = np.empty((0, 0))
  1667. cq, r = qr_multiply(a, c)
  1668. assert_allclose(cq, np.empty((0, 0)))
  1669. a = np.empty((0, 2))
  1670. c = np.empty((2, 0))
  1671. cq, r = qr_multiply(a, c)
  1672. assert_allclose(cq, np.empty((2, 0)))
  1673. a = np.empty((2, 0))
  1674. c = np.empty((0, 2))
  1675. cq, r = qr_multiply(a, c)
  1676. assert_allclose(cq, np.empty((0, 2)))
  1677. class TestRQ:
  1678. def test_simple(self):
  1679. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1680. r, q = rq(a)
  1681. assert_array_almost_equal(q @ q.T, eye(3))
  1682. assert_array_almost_equal(r @ q, a)
  1683. def test_r(self):
  1684. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1685. r, q = rq(a)
  1686. r2 = rq(a, mode='r')
  1687. assert_array_almost_equal(r, r2)
  1688. def test_random(self):
  1689. rng = np.random.RandomState(1234)
  1690. n = 20
  1691. for k in range(2):
  1692. a = rng.random([n, n])
  1693. r, q = rq(a)
  1694. assert_array_almost_equal(q @ q.T, eye(n))
  1695. assert_array_almost_equal(r @ q, a)
  1696. def test_simple_trap(self):
  1697. a = [[8, 2, 3], [2, 9, 3]]
  1698. r, q = rq(a)
  1699. assert_array_almost_equal(q.T @ q, eye(3))
  1700. assert_array_almost_equal(r @ q, a)
  1701. def test_simple_tall(self):
  1702. a = [[8, 2], [2, 9], [5, 3]]
  1703. r, q = rq(a)
  1704. assert_array_almost_equal(q.T @ q, eye(2))
  1705. assert_array_almost_equal(r @ q, a)
  1706. def test_simple_fat(self):
  1707. a = [[8, 2, 5], [2, 9, 3]]
  1708. r, q = rq(a)
  1709. assert_array_almost_equal(q @ q.T, eye(3))
  1710. assert_array_almost_equal(r @ q, a)
  1711. def test_simple_complex(self):
  1712. a = [[3, 3+4j, 5], [5, 2, 2+7j], [3, 2, 7]]
  1713. r, q = rq(a)
  1714. assert_array_almost_equal(q @ q.conj().T, eye(3))
  1715. assert_array_almost_equal(r @ q, a)
  1716. def test_random_tall(self):
  1717. rng = np.random.RandomState(1234)
  1718. m = 200
  1719. n = 100
  1720. for k in range(2):
  1721. a = rng.random([m, n])
  1722. r, q = rq(a)
  1723. assert_array_almost_equal(q @ q.T, eye(n))
  1724. assert_array_almost_equal(r @ q, a)
  1725. def test_random_trap(self):
  1726. rng = np.random.RandomState(1234)
  1727. m = 100
  1728. n = 200
  1729. for k in range(2):
  1730. a = rng.random([m, n])
  1731. r, q = rq(a)
  1732. assert_array_almost_equal(q @ q.T, eye(n))
  1733. assert_array_almost_equal(r @ q, a)
  1734. def test_random_trap_economic(self):
  1735. rng = np.random.RandomState(1234)
  1736. m = 100
  1737. n = 200
  1738. for k in range(2):
  1739. a = rng.random([m, n])
  1740. r, q = rq(a, mode='economic')
  1741. assert_array_almost_equal(q @ q.T, eye(m))
  1742. assert_array_almost_equal(r @ q, a)
  1743. assert_equal(q.shape, (m, n))
  1744. assert_equal(r.shape, (m, m))
  1745. def test_random_complex(self):
  1746. rng = np.random.RandomState(1234)
  1747. n = 20
  1748. for k in range(2):
  1749. a = rng.random([n, n]) + 1j*rng.random([n, n])
  1750. r, q = rq(a)
  1751. assert_array_almost_equal(q @ q.conj().T, eye(n))
  1752. assert_array_almost_equal(r @ q, a)
  1753. def test_random_complex_economic(self):
  1754. rng = np.random.RandomState(1234)
  1755. m = 100
  1756. n = 200
  1757. for k in range(2):
  1758. a = rng.random([m, n]) + 1j*rng.random([m, n])
  1759. r, q = rq(a, mode='economic')
  1760. assert_array_almost_equal(q @ q.conj().T, eye(m))
  1761. assert_array_almost_equal(r @ q, a)
  1762. assert_equal(q.shape, (m, n))
  1763. assert_equal(r.shape, (m, m))
  1764. def test_check_finite(self):
  1765. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1766. r, q = rq(a, check_finite=False)
  1767. assert_array_almost_equal(q @ q.T, eye(3))
  1768. assert_array_almost_equal(r @ q, a)
  1769. @pytest.mark.parametrize("m", [0, 1, 2])
  1770. @pytest.mark.parametrize("n", [0, 1, 2])
  1771. @pytest.mark.parametrize('dtype', DTYPES)
  1772. def test_shape_dtype(self, m, n, dtype):
  1773. k = min(m, n)
  1774. a = np.zeros((m, n), dtype=dtype)
  1775. r, q = rq(a)
  1776. assert_equal(q.shape, (n, n))
  1777. assert_equal(r.shape, (m, n))
  1778. assert_equal(r.dtype, dtype)
  1779. assert_equal(q.dtype, dtype)
  1780. r = rq(a, mode='r')
  1781. assert_equal(r.shape, (m, n))
  1782. assert_equal(r.dtype, dtype)
  1783. r, q = rq(a, mode='economic')
  1784. assert_equal(r.shape, (m, k))
  1785. assert_equal(r.dtype, dtype)
  1786. assert_equal(q.shape, (k, n))
  1787. assert_equal(q.dtype, dtype)
  1788. @pytest.mark.parametrize(("m", "n"), [(0, 0), (0, 2), (2, 0)])
  1789. def test_empty(self, m, n):
  1790. k = min(m, n)
  1791. a = np.empty((m, n))
  1792. r, q = rq(a)
  1793. assert_allclose(r, np.empty((m, n)))
  1794. assert_allclose(q, np.identity(n))
  1795. r = rq(a, mode='r')
  1796. assert_allclose(r, np.empty((m, n)))
  1797. r, q = rq(a, mode='economic')
  1798. assert_allclose(r, np.empty((m, k)))
  1799. assert_allclose(q, np.empty((k, n)))
  1800. class TestSchur:
  1801. def check_schur(self, a, t, u, rtol, atol):
  1802. # Check that the Schur decomposition is correct.
  1803. assert_allclose(u @ t @ u.conj().T, a, rtol=rtol, atol=atol,
  1804. err_msg="Schur decomposition does not match 'a'")
  1805. # The expected value of u @ u.H - I is all zeros, so test
  1806. # with absolute tolerance only.
  1807. assert_allclose(u @ u.conj().T - np.eye(len(u)), 0, rtol=0, atol=atol,
  1808. err_msg="u is not unitary")
  1809. def test_simple(self):
  1810. a = [[8, 12, 3], [2, 9, 3], [10, 3, 6]]
  1811. t, z = schur(a)
  1812. self.check_schur(a, t, z, rtol=1e-14, atol=5e-15)
  1813. tc, zc = schur(a, 'complex')
  1814. assert_(np.any(ravel(iscomplex(zc))) and np.any(ravel(iscomplex(tc))))
  1815. self.check_schur(a, tc, zc, rtol=1e-14, atol=5e-15)
  1816. tc2, zc2 = rsf2csf(tc, zc)
  1817. self.check_schur(a, tc2, zc2, rtol=1e-14, atol=5e-15)
  1818. @pytest.mark.parametrize(
  1819. 'sort, expected_diag',
  1820. [('lhp', [-np.sqrt(2), -0.5, np.sqrt(2), 0.5]),
  1821. ('rhp', [np.sqrt(2), 0.5, -np.sqrt(2), -0.5]),
  1822. ('iuc', [-0.5, 0.5, np.sqrt(2), -np.sqrt(2)]),
  1823. ('ouc', [np.sqrt(2), -np.sqrt(2), -0.5, 0.5]),
  1824. (lambda x: x >= 0.0, [np.sqrt(2), 0.5, -np.sqrt(2), -0.5])]
  1825. )
  1826. def test_sort(self, sort, expected_diag):
  1827. # The exact eigenvalues of this matrix are
  1828. # -sqrt(2), sqrt(2), -1/2, 1/2.
  1829. a = [[4., 3., 1., -1.],
  1830. [-4.5, -3.5, -1., 1.],
  1831. [9., 6., -4., 4.5],
  1832. [6., 4., -3., 3.5]]
  1833. t, u, sdim = schur(a, sort=sort)
  1834. self.check_schur(a, t, u, rtol=1e-14, atol=5e-15)
  1835. assert_allclose(np.diag(t), expected_diag, rtol=1e-12)
  1836. assert_equal(2, sdim)
  1837. def test_sort_errors(self):
  1838. a = [[4., 3., 1., -1.],
  1839. [-4.5, -3.5, -1., 1.],
  1840. [9., 6., -4., 4.5],
  1841. [6., 4., -3., 3.5]]
  1842. assert_raises(ValueError, schur, a, sort='unsupported')
  1843. assert_raises(ValueError, schur, a, sort=1)
  1844. def test_check_finite(self):
  1845. a = [[8, 12, 3], [2, 9, 3], [10, 3, 6]]
  1846. t, z = schur(a, check_finite=False)
  1847. assert_array_almost_equal(z @ t @ z.conj().T, a)
  1848. @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  1849. def test_empty(self, dt):
  1850. a = np.empty((0, 0), dtype=dt)
  1851. t, z = schur(a)
  1852. t0, z0 = schur(np.eye(2, dtype=dt))
  1853. assert_allclose(t, np.empty((0, 0)))
  1854. assert_allclose(z, np.empty((0, 0)))
  1855. assert t.dtype == t0.dtype
  1856. assert z.dtype == z0.dtype
  1857. t, z, sdim = schur(a, sort='lhp')
  1858. assert_allclose(t, np.empty((0, 0)))
  1859. assert_allclose(z, np.empty((0, 0)))
  1860. assert_equal(sdim, 0)
  1861. assert t.dtype == t0.dtype
  1862. assert z.dtype == z0.dtype
  1863. @pytest.mark.parametrize('sort', ['iuc', 'ouc'])
  1864. @pytest.mark.parametrize('output', ['real', 'complex'])
  1865. @pytest.mark.parametrize('dtype', [np.float32, np.float64,
  1866. np.complex64, np.complex128])
  1867. def test_gh_13137_sort_str(self, sort, output, dtype):
  1868. # gh-13137 reported that sort values 'iuc' and 'ouc' were not
  1869. # correct because the callables assumed that the eigenvalues would
  1870. # always be expressed as a single complex number.
  1871. # In fact, when `output='real'` and the dtype is real, the
  1872. # eigenvalues are passed as separate real and imaginary components
  1873. # (yet no error is raised if the callable accepts only one argument).
  1874. #
  1875. # This tests these sort values by counting the number of eigenvalues
  1876. # `schur` reports as being inside/outside the unit circle.
  1877. # Real matrix with eigenvalues 0.1 +- 2j
  1878. A = np.asarray([[0.1, -2], [2, 0.1]])
  1879. # Previously, this would fail for `output='real'` with real dtypes
  1880. sdim = schur(A.astype(dtype), sort=sort, output=output)[-1]
  1881. assert sdim == 0 if sort == 'iuc' else sdim == 2
  1882. @pytest.mark.parametrize('output', ['real', 'complex'])
  1883. @pytest.mark.parametrize('dtype', [np.float32, np.float64,
  1884. np.complex64, np.complex128])
  1885. def test_gh_13137_sort_custom(self, output, dtype):
  1886. # This simply tests our understanding of how eigenvalues are
  1887. # passed to a sort callable. If `output='real'` and the dtype is real,
  1888. # real and imaginary parts are passed as separate real arguments;
  1889. # otherwise, they are passed a single complex argument.
  1890. # Also, if `output='real'` and the dtype is real, when either
  1891. # eigenvalue in a complex conjugate pair satisfies the sort condition,
  1892. # `sdim` is incremented by TWO.
  1893. # Real matrix with eigenvalues 0.1 +- 2j
  1894. A = np.asarray([[0.1, -2], [2, 0.1]])
  1895. all_real = output=='real' and dtype in {np.float32, np.float64}
  1896. def sort(x, y=None):
  1897. if all_real:
  1898. assert not np.iscomplexobj(x)
  1899. assert y is not None and np.isreal(y)
  1900. z = x + y*1j
  1901. else:
  1902. assert np.iscomplexobj(x)
  1903. assert y is None
  1904. z = x
  1905. return z.imag > 1e-15
  1906. # Only one complex eigenvalue satisfies the condition, but when
  1907. # `all_real` applies, both eigenvalues in the complex conjugate pair
  1908. # are counted.
  1909. sdim = schur(A.astype(dtype), sort=sort, output=output)[-1]
  1910. assert sdim == 2 if all_real else sdim == 1
  1911. class TestHessenberg:
  1912. def test_simple(self):
  1913. a = [[-149, -50, -154],
  1914. [537, 180, 546],
  1915. [-27, -9, -25]]
  1916. h1 = [[-149.0000, 42.2037, -156.3165],
  1917. [-537.6783, 152.5511, -554.9272],
  1918. [0, 0.0728, 2.4489]]
  1919. h, q = hessenberg(a, calc_q=1)
  1920. assert_array_almost_equal(q.T @ a @ q, h)
  1921. assert_array_almost_equal(h, h1, decimal=4)
  1922. def test_simple_complex(self):
  1923. a = [[-149, -50, -154],
  1924. [537, 180j, 546],
  1925. [-27j, -9, -25]]
  1926. h, q = hessenberg(a, calc_q=1)
  1927. assert_array_almost_equal(q.conj().T @ a @ q, h)
  1928. def test_simple2(self):
  1929. a = [[1, 2, 3, 4, 5, 6, 7],
  1930. [0, 2, 3, 4, 6, 7, 2],
  1931. [0, 2, 2, 3, 0, 3, 2],
  1932. [0, 0, 2, 8, 0, 0, 2],
  1933. [0, 3, 1, 2, 0, 1, 2],
  1934. [0, 1, 2, 3, 0, 1, 0],
  1935. [0, 0, 0, 0, 0, 1, 2]]
  1936. h, q = hessenberg(a, calc_q=1)
  1937. assert_array_almost_equal(q.T @ a @ q, h)
  1938. def test_simple3(self):
  1939. a = np.eye(3)
  1940. a[-1, 0] = 2
  1941. h, q = hessenberg(a, calc_q=1)
  1942. assert_array_almost_equal(q.T @ a @ q, h)
  1943. def test_random(self):
  1944. rng = np.random.RandomState(1234)
  1945. n = 20
  1946. for k in range(2):
  1947. a = rng.random([n, n])
  1948. h, q = hessenberg(a, calc_q=1)
  1949. assert_array_almost_equal(q.T @ a @ q, h)
  1950. def test_random_complex(self):
  1951. rng = np.random.RandomState(1234)
  1952. n = 20
  1953. for k in range(2):
  1954. a = rng.random([n, n]) + 1j*rng.random([n, n])
  1955. h, q = hessenberg(a, calc_q=1)
  1956. assert_array_almost_equal(q.conj().T @ a @ q, h)
  1957. def test_check_finite(self):
  1958. a = [[-149, -50, -154],
  1959. [537, 180, 546],
  1960. [-27, -9, -25]]
  1961. h1 = [[-149.0000, 42.2037, -156.3165],
  1962. [-537.6783, 152.5511, -554.9272],
  1963. [0, 0.0728, 2.4489]]
  1964. h, q = hessenberg(a, calc_q=1, check_finite=False)
  1965. assert_array_almost_equal(q.T @ a @ q, h)
  1966. assert_array_almost_equal(h, h1, decimal=4)
  1967. def test_2x2(self):
  1968. a = [[2, 1], [7, 12]]
  1969. h, q = hessenberg(a, calc_q=1)
  1970. assert_array_almost_equal(q, np.eye(2))
  1971. assert_array_almost_equal(h, a)
  1972. b = [[2-7j, 1+2j], [7+3j, 12-2j]]
  1973. h2, q2 = hessenberg(b, calc_q=1)
  1974. assert_array_almost_equal(q2, np.eye(2))
  1975. assert_array_almost_equal(h2, b)
  1976. @pytest.mark.parametrize('dt', [int, float, float32, complex, complex64])
  1977. def test_empty(self, dt):
  1978. a = np.empty((0, 0), dtype=dt)
  1979. h = hessenberg(a)
  1980. assert h.shape == (0, 0)
  1981. assert h.dtype == hessenberg(np.eye(3, dtype=dt)).dtype
  1982. h, q = hessenberg(a, calc_q=True)
  1983. h3, q3 = hessenberg(a, calc_q=True)
  1984. assert h.shape == (0, 0)
  1985. assert h.dtype == h3.dtype
  1986. assert q.shape == (0, 0)
  1987. assert q.dtype == q3.dtype
  1988. blas_provider = blas_version = None
  1989. blas_provider = CONFIG['Build Dependencies']['blas']['name']
  1990. blas_version = CONFIG['Build Dependencies']['blas']['version']
  1991. class TestQZ:
  1992. def test_qz_single(self):
  1993. rng = np.random.RandomState(12345)
  1994. n = 5
  1995. A = rng.random([n, n]).astype(float32)
  1996. B = rng.random([n, n]).astype(float32)
  1997. AA, BB, Q, Z = qz(A, B)
  1998. assert_array_almost_equal(Q @ AA @ Z.T, A, decimal=5)
  1999. assert_array_almost_equal(Q @ BB @ Z.T, B, decimal=5)
  2000. assert_array_almost_equal(Q @ Q.T, eye(n), decimal=5)
  2001. assert_array_almost_equal(Z @ Z.T, eye(n), decimal=5)
  2002. assert_(np.all(diag(BB) >= 0))
  2003. def test_qz_double(self):
  2004. rng = np.random.RandomState(12345)
  2005. n = 5
  2006. A = rng.random([n, n])
  2007. B = rng.random([n, n])
  2008. AA, BB, Q, Z = qz(A, B)
  2009. assert_array_almost_equal(Q @ AA @ Z.T, A)
  2010. assert_array_almost_equal(Q @ BB @ Z.T, B)
  2011. assert_array_almost_equal(Q @ Q.T, eye(n))
  2012. assert_array_almost_equal(Z @ Z.T, eye(n))
  2013. assert_(np.all(diag(BB) >= 0))
  2014. def test_qz_complex(self):
  2015. rng = np.random.RandomState(12345)
  2016. n = 5
  2017. A = rng.random([n, n]) + 1j*rng.random([n, n])
  2018. B = rng.random([n, n]) + 1j*rng.random([n, n])
  2019. AA, BB, Q, Z = qz(A, B)
  2020. assert_array_almost_equal(Q @ AA @ Z.conj().T, A)
  2021. assert_array_almost_equal(Q @ BB @ Z.conj().T, B)
  2022. assert_array_almost_equal(Q @ Q.conj().T, eye(n))
  2023. assert_array_almost_equal(Z @ Z.conj().T, eye(n))
  2024. assert_(np.all(diag(BB) >= 0))
  2025. assert_(np.all(diag(BB).imag == 0))
  2026. def test_qz_complex64(self):
  2027. rng = np.random.RandomState(12345)
  2028. n = 5
  2029. A = (rng.random([n, n]) + 1j*rng.random([n, n])).astype(complex64)
  2030. B = (rng.random([n, n]) + 1j*rng.random([n, n])).astype(complex64)
  2031. AA, BB, Q, Z = qz(A, B)
  2032. assert_array_almost_equal(Q @ AA @ Z.conj().T, A, decimal=5)
  2033. assert_array_almost_equal(Q @ BB @ Z.conj().T, B, decimal=5)
  2034. assert_array_almost_equal(Q @ Q.conj().T, eye(n), decimal=5)
  2035. assert_array_almost_equal(Z @ Z.conj().T, eye(n), decimal=5)
  2036. assert_(np.all(diag(BB) >= 0))
  2037. assert_(np.all(diag(BB).imag == 0))
  2038. def test_qz_double_complex(self):
  2039. rng = np.random.RandomState(12345)
  2040. n = 5
  2041. A = rng.random([n, n])
  2042. B = rng.random([n, n])
  2043. AA, BB, Q, Z = qz(A, B, output='complex')
  2044. aa = Q @ AA @ Z.conj().T
  2045. assert_array_almost_equal(aa.real, A)
  2046. assert_array_almost_equal(aa.imag, 0)
  2047. bb = Q @ BB @ Z.conj().T
  2048. assert_array_almost_equal(bb.real, B)
  2049. assert_array_almost_equal(bb.imag, 0)
  2050. assert_array_almost_equal(Q @ Q.conj().T, eye(n))
  2051. assert_array_almost_equal(Z @ Z.conj().T, eye(n))
  2052. assert_(np.all(diag(BB) >= 0))
  2053. def test_qz_double_sort(self):
  2054. # from https://www.nag.com/lapack-ex/node119.html
  2055. # NOTE: These matrices may be ill-conditioned and lead to a
  2056. # seg fault on certain python versions when compiled with
  2057. # sse2 or sse3 older ATLAS/LAPACK binaries for windows
  2058. # A = np.array([[3.9, 12.5, -34.5, -0.5],
  2059. # [ 4.3, 21.5, -47.5, 7.5],
  2060. # [ 4.3, 21.5, -43.5, 3.5],
  2061. # [ 4.4, 26.0, -46.0, 6.0 ]])
  2062. # B = np.array([[ 1.0, 2.0, -3.0, 1.0],
  2063. # [1.0, 3.0, -5.0, 4.0],
  2064. # [1.0, 3.0, -4.0, 3.0],
  2065. # [1.0, 3.0, -4.0, 4.0]])
  2066. A = np.array([[3.9, 12.5, -34.5, 2.5],
  2067. [4.3, 21.5, -47.5, 7.5],
  2068. [4.3, 1.5, -43.5, 3.5],
  2069. [4.4, 6.0, -46.0, 6.0]])
  2070. B = np.array([[1.0, 1.0, -3.0, 1.0],
  2071. [1.0, 3.0, -5.0, 4.4],
  2072. [1.0, 2.0, -4.0, 1.0],
  2073. [1.2, 3.0, -4.0, 4.0]])
  2074. assert_raises(ValueError, qz, A, B, sort=lambda ar, ai, beta: ai == 0)
  2075. if False:
  2076. AA, BB, Q, Z, sdim = qz(A, B, sort=lambda ar, ai, beta: ai == 0)
  2077. # assert_(sdim == 2)
  2078. assert_(sdim == 4)
  2079. assert_array_almost_equal(Q @ AA @ Z.T, A)
  2080. assert_array_almost_equal(Q @ BB @ Z.T, B)
  2081. # test absolute values bc the sign is ambiguous and
  2082. # might be platform dependent
  2083. assert_array_almost_equal(np.abs(AA), np.abs(np.array(
  2084. [[35.7864, -80.9061, -12.0629, -9.498],
  2085. [0., 2.7638, -2.3505, 7.3256],
  2086. [0., 0., 0.6258, -0.0398],
  2087. [0., 0., 0., -12.8217]])), 4)
  2088. assert_array_almost_equal(np.abs(BB), np.abs(np.array(
  2089. [[4.5324, -8.7878, 3.2357, -3.5526],
  2090. [0., 1.4314, -2.1894, 0.9709],
  2091. [0., 0., 1.3126, -0.3468],
  2092. [0., 0., 0., 0.559]])), 4)
  2093. assert_array_almost_equal(np.abs(Q), np.abs(np.array(
  2094. [[-0.4193, -0.605, -0.1894, -0.6498],
  2095. [-0.5495, 0.6987, 0.2654, -0.3734],
  2096. [-0.4973, -0.3682, 0.6194, 0.4832],
  2097. [-0.5243, 0.1008, -0.7142, 0.4526]])), 4)
  2098. assert_array_almost_equal(np.abs(Z), np.abs(np.array(
  2099. [[-0.9471, -0.2971, -0.1217, 0.0055],
  2100. [-0.0367, 0.1209, 0.0358, 0.9913],
  2101. [0.3171, -0.9041, -0.2547, 0.1312],
  2102. [0.0346, 0.2824, -0.9587, 0.0014]])), 4)
  2103. # test absolute values bc the sign is ambiguous and might be platform
  2104. # dependent
  2105. # assert_array_almost_equal(abs(AA), abs(np.array([
  2106. # [3.8009, -69.4505, 50.3135, -43.2884],
  2107. # [0.0000, 9.2033, -0.2001, 5.9881],
  2108. # [0.0000, 0.0000, 1.4279, 4.4453],
  2109. # [0.0000, 0.0000, 0.9019, -1.1962]])), 4)
  2110. # assert_array_almost_equal(abs(BB), abs(np.array([
  2111. # [1.9005, -10.2285, 0.8658, -5.2134],
  2112. # [0.0000, 2.3008, 0.7915, 0.4262],
  2113. # [0.0000, 0.0000, 0.8101, 0.0000],
  2114. # [0.0000, 0.0000, 0.0000, -0.2823]])), 4)
  2115. # assert_array_almost_equal(abs(Q), abs(np.array([
  2116. # [0.4642, 0.7886, 0.2915, -0.2786],
  2117. # [0.5002, -0.5986, 0.5638, -0.2713],
  2118. # [0.5002, 0.0154, -0.0107, 0.8657],
  2119. # [0.5331, -0.1395, -0.7727, -0.3151]])), 4)
  2120. # assert_array_almost_equal(dot(Q,Q.T), eye(4))
  2121. # assert_array_almost_equal(abs(Z), abs(np.array([
  2122. # [0.9961, -0.0014, 0.0887, -0.0026],
  2123. # [0.0057, -0.0404, -0.0938, -0.9948],
  2124. # [0.0626, 0.7194, -0.6908, 0.0363],
  2125. # [0.0626, -0.6934, -0.7114, 0.0956]])), 4)
  2126. # assert_array_almost_equal(dot(Z,Z.T), eye(4))
  2127. # def test_qz_complex_sort(self):
  2128. # cA = np.array([
  2129. # [-21.10+22.50*1j, 53.50+-50.50*1j, -34.50+127.50*1j, 7.50+ 0.50*1j],
  2130. # [-0.46+ -7.78*1j, -3.50+-37.50*1j, -15.50+ 58.50*1j,-10.50+ -1.50*1j],
  2131. # [ 4.30+ -5.50*1j, 39.70+-17.10*1j, -68.50+ 12.50*1j, -7.50+ -3.50*1j],
  2132. # [ 5.50+ 4.40*1j, 14.40+ 43.30*1j, -32.50+-46.00*1j,-19.00+-32.50*1j]])
  2133. # cB = np.array([
  2134. # [1.00+ -5.00*1j, 1.60+ 1.20*1j,-3.00+ 0.00*1j, 0.00+ -1.00*1j],
  2135. # [0.80+ -0.60*1j, 3.00+ -5.00*1j,-4.00+ 3.00*1j,-2.40+ -3.20*1j],
  2136. # [1.00+ 0.00*1j, 2.40+ 1.80*1j,-4.00+ -5.00*1j, 0.00+ -3.00*1j],
  2137. # [0.00+ 1.00*1j,-1.80+ 2.40*1j, 0.00+ -4.00*1j, 4.00+ -5.00*1j]])
  2138. # AAS,BBS,QS,ZS,sdim = qz(cA,cB,sort='lhp')
  2139. # eigenvalues = diag(AAS)/diag(BBS)
  2140. # assert_(np.all(np.real(eigenvalues[:sdim] < 0)))
  2141. # assert_(np.all(np.real(eigenvalues[sdim:] > 0)))
  2142. def test_check_finite(self):
  2143. rng = np.random.RandomState(12345)
  2144. n = 5
  2145. A = rng.random([n, n])
  2146. B = rng.random([n, n])
  2147. AA, BB, Q, Z = qz(A, B, check_finite=False)
  2148. assert_array_almost_equal(Q @ AA @ Z.T, A)
  2149. assert_array_almost_equal(Q @ BB @ Z.T, B)
  2150. assert_array_almost_equal(Q @ Q.T, eye(n))
  2151. assert_array_almost_equal(Z @ Z.T, eye(n))
  2152. assert_(np.all(diag(BB) >= 0))
  2153. class TestOrdQZ:
  2154. @classmethod
  2155. def setup_class(cls):
  2156. # https://www.nag.com/lapack-ex/node119.html
  2157. A1 = np.array([[-21.10 - 22.50j, 53.5 - 50.5j, -34.5 + 127.5j,
  2158. 7.5 + 0.5j],
  2159. [-0.46 - 7.78j, -3.5 - 37.5j, -15.5 + 58.5j,
  2160. -10.5 - 1.5j],
  2161. [4.30 - 5.50j, 39.7 - 17.1j, -68.5 + 12.5j,
  2162. -7.5 - 3.5j],
  2163. [5.50 + 4.40j, 14.4 + 43.3j, -32.5 - 46.0j,
  2164. -19.0 - 32.5j]])
  2165. B1 = np.array([[1.0 - 5.0j, 1.6 + 1.2j, -3 + 0j, 0.0 - 1.0j],
  2166. [0.8 - 0.6j, .0 - 5.0j, -4 + 3j, -2.4 - 3.2j],
  2167. [1.0 + 0.0j, 2.4 + 1.8j, -4 - 5j, 0.0 - 3.0j],
  2168. [0.0 + 1.0j, -1.8 + 2.4j, 0 - 4j, 4.0 - 5.0j]])
  2169. # https://www.nag.com/numeric/fl/nagdoc_fl23/xhtml/F08/f08yuf.xml
  2170. A2 = np.array([[3.9, 12.5, -34.5, -0.5],
  2171. [4.3, 21.5, -47.5, 7.5],
  2172. [4.3, 21.5, -43.5, 3.5],
  2173. [4.4, 26.0, -46.0, 6.0]])
  2174. B2 = np.array([[1, 2, -3, 1],
  2175. [1, 3, -5, 4],
  2176. [1, 3, -4, 3],
  2177. [1, 3, -4, 4]])
  2178. # example with the eigenvalues
  2179. # -0.33891648, 1.61217396+0.74013521j, 1.61217396-0.74013521j,
  2180. # 0.61244091
  2181. # thus featuring:
  2182. # * one complex conjugate eigenvalue pair,
  2183. # * one eigenvalue in the lhp
  2184. # * 2 eigenvalues in the unit circle
  2185. # * 2 non-real eigenvalues
  2186. A3 = np.array([[5., 1., 3., 3.],
  2187. [4., 4., 2., 7.],
  2188. [7., 4., 1., 3.],
  2189. [0., 4., 8., 7.]])
  2190. B3 = np.array([[8., 10., 6., 10.],
  2191. [7., 7., 2., 9.],
  2192. [9., 1., 6., 6.],
  2193. [5., 1., 4., 7.]])
  2194. # example with infinite eigenvalues
  2195. A4 = np.eye(2)
  2196. B4 = np.diag([0, 1])
  2197. # example with (alpha, beta) = (0, 0)
  2198. A5 = np.diag([1, 0])
  2199. cls.A = [A1, A2, A3, A4, A5]
  2200. cls.B = [B1, B2, B3, B4, A5]
  2201. def qz_decomp(self, sort):
  2202. with np.errstate(all='raise'):
  2203. ret = [ordqz(Ai, Bi, sort=sort) for Ai, Bi in zip(self.A, self.B)]
  2204. return tuple(ret)
  2205. def check(self, A, B, sort, AA, BB, alpha, beta, Q, Z):
  2206. Id = np.eye(*A.shape)
  2207. # make sure Q and Z are orthogonal
  2208. assert_array_almost_equal(Q @ Q.T.conj(), Id)
  2209. assert_array_almost_equal(Z @ Z.T.conj(), Id)
  2210. # check factorization
  2211. assert_array_almost_equal(Q @ AA, A @ Z)
  2212. assert_array_almost_equal(Q @ BB, B @ Z)
  2213. # check shape of AA and BB
  2214. assert_array_equal(np.tril(AA, -2), np.zeros(AA.shape))
  2215. assert_array_equal(np.tril(BB, -1), np.zeros(BB.shape))
  2216. # check eigenvalues
  2217. for i in range(A.shape[0]):
  2218. # does the current diagonal element belong to a 2-by-2 block
  2219. # that was already checked?
  2220. if i > 0 and A[i, i - 1] != 0:
  2221. continue
  2222. # take care of 2-by-2 blocks
  2223. if i < AA.shape[0] - 1 and AA[i + 1, i] != 0:
  2224. evals, _ = eig(AA[i:i + 2, i:i + 2], BB[i:i + 2, i:i + 2])
  2225. # make sure the pair of complex conjugate eigenvalues
  2226. # is ordered consistently (positive imaginary part first)
  2227. if evals[0].imag < 0:
  2228. evals = evals[[1, 0]]
  2229. tmp = alpha[i:i + 2]/beta[i:i + 2]
  2230. if tmp[0].imag < 0:
  2231. tmp = tmp[[1, 0]]
  2232. assert_array_almost_equal(evals, tmp)
  2233. else:
  2234. if alpha[i] == 0 and beta[i] == 0:
  2235. assert_equal(AA[i, i], 0)
  2236. assert_equal(BB[i, i], 0)
  2237. elif beta[i] == 0:
  2238. assert_equal(BB[i, i], 0)
  2239. else:
  2240. assert_almost_equal(AA[i, i]/BB[i, i], alpha[i]/beta[i])
  2241. sortfun = _select_function(sort)
  2242. lastsort = True
  2243. for i in range(A.shape[0]):
  2244. cursort = sortfun(np.array([alpha[i]]), np.array([beta[i]]))
  2245. # once the sorting criterion was not matched all subsequent
  2246. # eigenvalues also shouldn't match
  2247. if not lastsort:
  2248. assert not cursort
  2249. lastsort = cursort
  2250. def check_all(self, sort):
  2251. ret = self.qz_decomp(sort)
  2252. for reti, Ai, Bi in zip(ret, self.A, self.B):
  2253. self.check(Ai, Bi, sort, *reti)
  2254. def test_lhp(self):
  2255. self.check_all('lhp')
  2256. def test_rhp(self):
  2257. self.check_all('rhp')
  2258. def test_iuc(self):
  2259. self.check_all('iuc')
  2260. def test_ouc(self):
  2261. self.check_all('ouc')
  2262. def test_ref(self):
  2263. # real eigenvalues first (top-left corner)
  2264. def sort(x, y):
  2265. out = np.empty_like(x, dtype=bool)
  2266. nonzero = (y != 0)
  2267. out[~nonzero] = False
  2268. out[nonzero] = (x[nonzero]/y[nonzero]).imag == 0
  2269. return out
  2270. self.check_all(sort)
  2271. def test_cef(self):
  2272. # complex eigenvalues first (top-left corner)
  2273. def sort(x, y):
  2274. out = np.empty_like(x, dtype=bool)
  2275. nonzero = (y != 0)
  2276. out[~nonzero] = False
  2277. out[nonzero] = (x[nonzero]/y[nonzero]).imag != 0
  2278. return out
  2279. self.check_all(sort)
  2280. def test_diff_input_types(self):
  2281. ret = ordqz(self.A[1], self.B[2], sort='lhp')
  2282. self.check(self.A[1], self.B[2], 'lhp', *ret)
  2283. ret = ordqz(self.B[2], self.A[1], sort='lhp')
  2284. self.check(self.B[2], self.A[1], 'lhp', *ret)
  2285. def test_sort_explicit(self):
  2286. # Test order of the eigenvalues in the 2 x 2 case where we can
  2287. # explicitly compute the solution
  2288. A1 = np.eye(2)
  2289. B1 = np.diag([-2, 0.5])
  2290. expected1 = [('lhp', [-0.5, 2]),
  2291. ('rhp', [2, -0.5]),
  2292. ('iuc', [-0.5, 2]),
  2293. ('ouc', [2, -0.5])]
  2294. A2 = np.eye(2)
  2295. B2 = np.diag([-2 + 1j, 0.5 + 0.5j])
  2296. expected2 = [('lhp', [1/(-2 + 1j), 1/(0.5 + 0.5j)]),
  2297. ('rhp', [1/(0.5 + 0.5j), 1/(-2 + 1j)]),
  2298. ('iuc', [1/(-2 + 1j), 1/(0.5 + 0.5j)]),
  2299. ('ouc', [1/(0.5 + 0.5j), 1/(-2 + 1j)])]
  2300. # 'lhp' is ambiguous so don't test it
  2301. A3 = np.eye(2)
  2302. B3 = np.diag([2, 0])
  2303. expected3 = [('rhp', [0.5, np.inf]),
  2304. ('iuc', [0.5, np.inf]),
  2305. ('ouc', [np.inf, 0.5])]
  2306. # 'rhp' is ambiguous so don't test it
  2307. A4 = np.eye(2)
  2308. B4 = np.diag([-2, 0])
  2309. expected4 = [('lhp', [-0.5, np.inf]),
  2310. ('iuc', [-0.5, np.inf]),
  2311. ('ouc', [np.inf, -0.5])]
  2312. A5 = np.diag([0, 1])
  2313. B5 = np.diag([0, 0.5])
  2314. # 'lhp' and 'iuc' are ambiguous so don't test them
  2315. expected5 = [('rhp', [2, np.nan]),
  2316. ('ouc', [2, np.nan])]
  2317. A = [A1, A2, A3, A4, A5]
  2318. B = [B1, B2, B3, B4, B5]
  2319. expected = [expected1, expected2, expected3, expected4, expected5]
  2320. for Ai, Bi, expectedi in zip(A, B, expected):
  2321. for sortstr, expected_eigvals in expectedi:
  2322. _, _, alpha, beta, _, _ = ordqz(Ai, Bi, sort=sortstr)
  2323. azero = (alpha == 0)
  2324. bzero = (beta == 0)
  2325. x = np.empty_like(alpha)
  2326. x[azero & bzero] = np.nan
  2327. x[~azero & bzero] = np.inf
  2328. x[~bzero] = alpha[~bzero]/beta[~bzero]
  2329. assert_allclose(expected_eigvals, x)
  2330. class TestOrdQZWorkspaceSize:
  2331. @pytest.mark.fail_slow(5)
  2332. def test_decompose(self):
  2333. rng = np.random.RandomState(12345)
  2334. N = 202
  2335. # raises error if lwork parameter to dtrsen is too small
  2336. for ddtype in [np.float32, np.float64]:
  2337. A = rng.random((N, N)).astype(ddtype)
  2338. B = rng.random((N, N)).astype(ddtype)
  2339. # sort = lambda ar, ai, b: ar**2 + ai**2 < b**2
  2340. _ = ordqz(A, B, sort=lambda alpha, beta: alpha < beta,
  2341. output='real')
  2342. for ddtype in [np.complex128, np.complex64]:
  2343. A = rng.random((N, N)).astype(ddtype)
  2344. B = rng.random((N, N)).astype(ddtype)
  2345. _ = ordqz(A, B, sort=lambda alpha, beta: alpha < beta,
  2346. output='complex')
  2347. @pytest.mark.slow
  2348. def test_decompose_ouc(self):
  2349. rng = np.random.RandomState(12345)
  2350. N = 202
  2351. # segfaults if lwork parameter to dtrsen is too small
  2352. for ddtype in [np.float32, np.float64, np.complex128, np.complex64]:
  2353. A = rng.random((N, N)).astype(ddtype)
  2354. B = rng.random((N, N)).astype(ddtype)
  2355. S, T, alpha, beta, U, V = ordqz(A, B, sort='ouc')
  2356. class TestDatacopied:
  2357. def test_datacopied(self):
  2358. from scipy.linalg._decomp import _datacopied
  2359. M = matrix([[0, 1], [2, 3]])
  2360. A = asarray(M)
  2361. L = M.tolist()
  2362. M2 = M.copy()
  2363. class Fake1:
  2364. def __array__(self, dtype=None, copy=None):
  2365. return A
  2366. class Fake2:
  2367. __array_interface__ = A.__array_interface__
  2368. F1 = Fake1()
  2369. F2 = Fake2()
  2370. for item, status in [(M, False), (A, False), (L, True),
  2371. (M2, False), (F1, False), (F2, False)]:
  2372. arr = asarray(item)
  2373. assert_equal(_datacopied(arr, item), status,
  2374. err_msg=repr(item))
  2375. def test_aligned_mem_float():
  2376. """Check linalg works with non-aligned memory (float32)"""
  2377. # Allocate 402 bytes of memory (allocated on boundary)
  2378. a = arange(402, dtype=np.uint8)
  2379. # Create an array with boundary offset 4
  2380. z = np.frombuffer(a.data, offset=2, count=100, dtype=float32)
  2381. z = z.reshape((10, 10))
  2382. eig(z, overwrite_a=True)
  2383. eig(z.T, overwrite_a=True)
  2384. @pytest.mark.skipif(platform.machine() == 'ppc64le',
  2385. reason="crashes on ppc64le")
  2386. def test_aligned_mem():
  2387. """Check linalg works with non-aligned memory (float64)"""
  2388. # Allocate 804 bytes of memory (allocated on boundary)
  2389. a = arange(804, dtype=np.uint8)
  2390. # Create an array with boundary offset 4
  2391. z = np.frombuffer(a.data, offset=4, count=100, dtype=float)
  2392. z = z.reshape((10, 10))
  2393. eig(z, overwrite_a=True)
  2394. eig(z.T, overwrite_a=True)
  2395. def test_aligned_mem_complex():
  2396. """Check that complex objects don't need to be completely aligned"""
  2397. # Allocate 1608 bytes of memory (allocated on boundary)
  2398. a = zeros(1608, dtype=np.uint8)
  2399. # Create an array with boundary offset 8
  2400. z = np.frombuffer(a.data, offset=8, count=100, dtype=complex)
  2401. z = z.reshape((10, 10))
  2402. eig(z, overwrite_a=True)
  2403. # This does not need special handling
  2404. eig(z.T, overwrite_a=True)
  2405. def check_lapack_misaligned(func, args, kwargs):
  2406. args = list(args)
  2407. for i in range(len(args)):
  2408. a = args[:]
  2409. if isinstance(a[i], np.ndarray):
  2410. # Try misaligning a[i]
  2411. aa = np.zeros(a[i].size*a[i].dtype.itemsize+8, dtype=np.uint8)
  2412. aa = np.frombuffer(aa.data, offset=4, count=a[i].size,
  2413. dtype=a[i].dtype)
  2414. aa = aa.reshape(a[i].shape)
  2415. aa[...] = a[i]
  2416. a[i] = aa
  2417. func(*a, **kwargs)
  2418. if len(a[i].shape) > 1:
  2419. a[i] = a[i].T
  2420. func(*a, **kwargs)
  2421. @pytest.mark.xfail(run=False,
  2422. reason="Ticket #1152, triggers a segfault in rare cases.")
  2423. def test_lapack_misaligned():
  2424. M = np.eye(10, dtype=float)
  2425. R = np.arange(100).reshape((10, 10))
  2426. S = np.arange(20000, dtype=np.uint8)
  2427. S = np.frombuffer(S.data, offset=4, count=100, dtype=float)
  2428. S = S.reshape((10, 10))
  2429. b = np.ones(10)
  2430. LU, piv = lu_factor(S)
  2431. for (func, args, kwargs) in [
  2432. (eig, (S,), dict(overwrite_a=True)), # crash
  2433. (eigvals, (S,), dict(overwrite_a=True)), # no crash
  2434. (lu, (S,), dict(overwrite_a=True)), # no crash
  2435. (lu_factor, (S,), dict(overwrite_a=True)), # no crash
  2436. (lu_solve, ((LU, piv), b), dict(overwrite_b=True)),
  2437. (solve, (S, b), dict(overwrite_a=True, overwrite_b=True)),
  2438. (svd, (M,), dict(overwrite_a=True)), # no crash
  2439. (svd, (R,), dict(overwrite_a=True)), # no crash
  2440. (svd, (S,), dict(overwrite_a=True)), # crash
  2441. (svdvals, (S,), dict()), # no crash
  2442. (svdvals, (S,), dict(overwrite_a=True)), # crash
  2443. (cholesky, (M,), dict(overwrite_a=True)), # no crash
  2444. (qr, (S,), dict(overwrite_a=True)), # crash
  2445. (rq, (S,), dict(overwrite_a=True)), # crash
  2446. (hessenberg, (S,), dict(overwrite_a=True)), # crash
  2447. (schur, (S,), dict(overwrite_a=True)), # crash
  2448. ]:
  2449. check_lapack_misaligned(func, args, kwargs)
# Not yet properly tested:
# cholesky, rsf2csf, lu_solve, solve, eig_banded, eigvals_banded, eigh, diagsvd
  2452. class TestOverwrite:
  2453. def test_eig(self):
  2454. assert_no_overwrite(eig, [(3, 3)])
  2455. assert_no_overwrite(eig, [(3, 3), (3, 3)])
  2456. def test_eigh(self):
  2457. assert_no_overwrite(eigh, [(3, 3)])
  2458. assert_no_overwrite(eigh, [(3, 3), (3, 3)])
  2459. def test_eig_banded(self):
  2460. assert_no_overwrite(eig_banded, [(3, 2)])
  2461. def test_eigvals(self):
  2462. assert_no_overwrite(eigvals, [(3, 3)])
  2463. def test_eigvalsh(self):
  2464. assert_no_overwrite(eigvalsh, [(3, 3)])
  2465. def test_eigvals_banded(self):
  2466. assert_no_overwrite(eigvals_banded, [(3, 2)])
  2467. def test_hessenberg(self):
  2468. assert_no_overwrite(hessenberg, [(3, 3)])
  2469. def test_lu_factor(self):
  2470. assert_no_overwrite(lu_factor, [(3, 3)])
  2471. def test_lu_solve(self):
  2472. x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 8]])
  2473. xlu = lu_factor(x)
  2474. assert_no_overwrite(lambda b: lu_solve(xlu, b), [(3,)])
  2475. def test_lu(self):
  2476. assert_no_overwrite(lu, [(3, 3)])
  2477. def test_qr(self):
  2478. assert_no_overwrite(qr, [(3, 3)])
  2479. def test_rq(self):
  2480. assert_no_overwrite(rq, [(3, 3)])
  2481. def test_schur(self):
  2482. assert_no_overwrite(schur, [(3, 3)])
  2483. def test_schur_complex(self):
  2484. assert_no_overwrite(lambda a: schur(a, 'complex'), [(3, 3)],
  2485. dtypes=[np.float32, np.float64])
  2486. def test_svd(self):
  2487. assert_no_overwrite(svd, [(3, 3)])
  2488. assert_no_overwrite(lambda a: svd(a, lapack_driver='gesvd'), [(3, 3)])
  2489. def test_svdvals(self):
  2490. assert_no_overwrite(svdvals, [(3, 3)])
  2491. def _check_orth(n, dtype, skip_big=False):
  2492. X = np.ones((n, 2), dtype=float).astype(dtype)
  2493. eps = np.finfo(dtype).eps
  2494. tol = 1000 * eps
  2495. Y = orth(X)
  2496. assert_equal(Y.shape, (n, 1))
  2497. assert_allclose(Y, Y.mean(), atol=tol, rtol=1.4e-7)
  2498. Y = orth(X.T)
  2499. assert_equal(Y.shape, (2, 1))
  2500. assert_allclose(Y, Y.mean(), atol=tol)
  2501. if n > 5 and not skip_big:
  2502. rng = np.random.RandomState(1)
  2503. X = rng.rand(n, 5) @ rng.rand(5, n)
  2504. X = X + 1e-4 * rng.rand(n, 1) @ rng.rand(1, n)
  2505. X = X.astype(dtype)
  2506. Y = orth(X, rcond=1e-3)
  2507. assert_equal(Y.shape, (n, 5))
  2508. Y = orth(X, rcond=1e-6)
  2509. assert_equal(Y.shape, (n, 5 + 1))
  2510. @pytest.mark.slow
  2511. @pytest.mark.skipif(np.dtype(np.intp).itemsize < 8,
  2512. reason="test only on 64-bit, else too slow")
  2513. def test_orth_memory_efficiency():
  2514. # Pick n so that 16*n bytes is reasonable but 8*n*n bytes is unreasonable.
  2515. # Keep in mind that @pytest.mark.slow tests are likely to be running
  2516. # under configurations that support 4Gb+ memory for tests related to
  2517. # 32 bit overflow.
  2518. n = 10*1000*1000
  2519. try:
  2520. _check_orth(n, np.float64, skip_big=True)
  2521. except MemoryError as e:
  2522. raise AssertionError(
  2523. 'memory error perhaps caused by orth regression'
  2524. ) from e
  2525. def test_orth():
  2526. dtypes = [np.float32, np.float64, np.complex64, np.complex128]
  2527. sizes = [1, 2, 3, 10, 100]
  2528. for dt, n in itertools.product(dtypes, sizes):
  2529. _check_orth(n, dt)
  2530. @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  2531. def test_orth_empty(dt):
  2532. a = np.empty((0, 0), dtype=dt)
  2533. a0 = np.eye(2, dtype=dt)
  2534. oa = orth(a)
  2535. assert oa.dtype == orth(a0).dtype
  2536. assert oa.shape == (0, 0)
  2537. class TestNullSpace:
  2538. def test_null_space(self):
  2539. rng = np.random.RandomState(1)
  2540. dtypes = [np.float32, np.float64, np.complex64, np.complex128]
  2541. sizes = [1, 2, 3, 10, 100]
  2542. for dt, n in itertools.product(dtypes, sizes):
  2543. X = np.ones((2, n), dtype=dt)
  2544. eps = np.finfo(dt).eps
  2545. tol = 1000 * eps
  2546. Y = null_space(X)
  2547. assert_equal(Y.shape, (n, n-1))
  2548. assert_allclose(X @ Y, 0, atol=tol)
  2549. Y = null_space(X.T)
  2550. assert_equal(Y.shape, (2, 1))
  2551. assert_allclose(X.T @ Y, 0, atol=tol)
  2552. X = rng.randn(1 + n//2, n)
  2553. Y = null_space(X)
  2554. assert_equal(Y.shape, (n, n - 1 - n//2))
  2555. assert_allclose(X @ Y, 0, atol=tol)
  2556. if n > 5:
  2557. rng = np.random.RandomState(1)
  2558. X = rng.rand(n, 5) @ rng.rand(5, n)
  2559. X = X + 1e-4 * rng.rand(n, 1) @ rng.rand(1, n)
  2560. X = X.astype(dt)
  2561. Y = null_space(X, rcond=1e-3)
  2562. assert_equal(Y.shape, (n, n - 5))
  2563. Y = null_space(X, rcond=1e-6)
  2564. assert_equal(Y.shape, (n, n - 6))
  2565. @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  2566. def test_null_space_empty(self, dt):
  2567. a = np.empty((0, 0), dtype=dt)
  2568. a0 = np.eye(2, dtype=dt)
  2569. nsa = null_space(a)
  2570. assert nsa.shape == (0, 0)
  2571. assert nsa.dtype == null_space(a0).dtype
  2572. @pytest.mark.parametrize("overwrite_a", [True, False])
  2573. @pytest.mark.parametrize("check_finite", [True, False])
  2574. @pytest.mark.parametrize("lapack_driver", ["gesdd", "gesvd"])
  2575. def test_null_space_options(self, overwrite_a, check_finite, lapack_driver):
  2576. rng = np.random.default_rng(42887289350573064398746)
  2577. n = 10
  2578. X = rng.standard_normal((1 + n//2, n))
  2579. Y = null_space(X.copy(), overwrite_a=overwrite_a, check_finite=check_finite,
  2580. lapack_driver=lapack_driver)
  2581. assert_allclose(X @ Y, 0, atol=np.finfo(X.dtype).eps*100)
  2582. def test_subspace_angles():
  2583. H = hadamard(8, float)
  2584. A = H[:, :3]
  2585. B = H[:, 3:]
  2586. assert_allclose(subspace_angles(A, B), [np.pi / 2.] * 3, atol=1e-14)
  2587. assert_allclose(subspace_angles(B, A), [np.pi / 2.] * 3, atol=1e-14)
  2588. for x in (A, B):
  2589. assert_allclose(subspace_angles(x, x), np.zeros(x.shape[1]),
  2590. atol=1e-14)
  2591. # From MATLAB function "subspace", which effectively only returns the
  2592. # last value that we calculate
  2593. x = np.array(
  2594. [[0.537667139546100, 0.318765239858981, 3.578396939725760, 0.725404224946106], # noqa: E501
  2595. [1.833885014595086, -1.307688296305273, 2.769437029884877, -0.063054873189656], # noqa: E501
  2596. [-2.258846861003648, -0.433592022305684, -1.349886940156521, 0.714742903826096], # noqa: E501
  2597. [0.862173320368121, 0.342624466538650, 3.034923466331855, -0.204966058299775]]) # noqa: E501
  2598. expected = 1.481454682101605
  2599. assert_allclose(subspace_angles(x[:, :2], x[:, 2:])[0], expected,
  2600. rtol=1e-12)
  2601. assert_allclose(subspace_angles(x[:, 2:], x[:, :2])[0], expected,
  2602. rtol=1e-12)
  2603. expected = 0.746361174247302
  2604. assert_allclose(subspace_angles(x[:, :2], x[:, [2]]), expected, rtol=1e-12)
  2605. assert_allclose(subspace_angles(x[:, [2]], x[:, :2]), expected, rtol=1e-12)
  2606. expected = 0.487163718534313
  2607. assert_allclose(subspace_angles(x[:, :3], x[:, [3]]), expected, rtol=1e-12)
  2608. assert_allclose(subspace_angles(x[:, [3]], x[:, :3]), expected, rtol=1e-12)
  2609. expected = 0.328950515907756
  2610. assert_allclose(subspace_angles(x[:, :2], x[:, 1:]), [expected, 0],
  2611. atol=1e-12)
  2612. # Degenerate conditions
  2613. assert_raises(ValueError, subspace_angles, x[0], x)
  2614. assert_raises(ValueError, subspace_angles, x, x[0])
  2615. assert_raises(ValueError, subspace_angles, x[:-1], x)
  2616. # Test branch if mask.any is True:
  2617. A = np.array([[1, 0, 0],
  2618. [0, 1, 0],
  2619. [0, 0, 1],
  2620. [0, 0, 0],
  2621. [0, 0, 0]])
  2622. B = np.array([[1, 0, 0],
  2623. [0, 1, 0],
  2624. [0, 0, 0],
  2625. [0, 0, 0],
  2626. [0, 0, 1]])
  2627. expected = np.array([np.pi/2, 0, 0])
  2628. assert_allclose(subspace_angles(A, B), expected, rtol=1e-12)
  2629. # Complex
  2630. # second column in "b" does not affect result, just there so that
  2631. # b can have more cols than a, and vice-versa (both conditional code paths)
  2632. a = [[1 + 1j], [0]]
  2633. b = [[1 - 1j, 0], [0, 1]]
  2634. assert_allclose(subspace_angles(a, b), 0., atol=1e-14)
  2635. assert_allclose(subspace_angles(b, a), 0., atol=1e-14)
  2636. # Empty
  2637. a = np.empty((0, 0))
  2638. b = np.empty((0, 0))
  2639. assert_allclose(subspace_angles(a, b), np.empty((0,)))
  2640. a = np.empty((2, 0))
  2641. b = np.empty((2, 0))
  2642. assert_allclose(subspace_angles(a, b), np.empty((0,)))
  2643. a = np.empty((0, 2))
  2644. b = np.empty((0, 3))
  2645. assert_allclose(subspace_angles(a, b), np.empty((0,)))
  2646. class TestCDF2RDF:
  2647. def matmul(self, a, b):
  2648. return np.einsum('...ij,...jk->...ik', a, b)
  2649. def assert_eig_valid(self, w, v, x):
  2650. assert_array_almost_equal(
  2651. self.matmul(v, w),
  2652. self.matmul(x, v)
  2653. )
  2654. def test_single_array0x0real(self):
  2655. # eig doesn't support 0x0 in old versions of numpy
  2656. X = np.empty((0, 0))
  2657. w, v = np.empty(0), np.empty((0, 0))
  2658. wr, vr = cdf2rdf(w, v)
  2659. self.assert_eig_valid(wr, vr, X)
  2660. def test_single_array2x2_real(self):
  2661. X = np.array([[1, 2], [3, -1]])
  2662. w, v = np.linalg.eig(X)
  2663. wr, vr = cdf2rdf(w, v)
  2664. self.assert_eig_valid(wr, vr, X)
  2665. def test_single_array2x2_complex(self):
  2666. X = np.array([[1, 2], [-2, 1]])
  2667. w, v = np.linalg.eig(X)
  2668. wr, vr = cdf2rdf(w, v)
  2669. self.assert_eig_valid(wr, vr, X)
  2670. def test_single_array3x3_real(self):
  2671. X = np.array([[1, 2, 3], [1, 2, 3], [2, 5, 6]])
  2672. w, v = np.linalg.eig(X)
  2673. wr, vr = cdf2rdf(w, v)
  2674. self.assert_eig_valid(wr, vr, X)
  2675. def test_single_array3x3_complex(self):
  2676. X = np.array([[1, 2, 3], [0, 4, 5], [0, -5, 4]])
  2677. w, v = np.linalg.eig(X)
  2678. wr, vr = cdf2rdf(w, v)
  2679. self.assert_eig_valid(wr, vr, X)
  2680. def test_random_1d_stacked_arrays(self):
  2681. rng = np.random.default_rng(1234)
  2682. # cannot test M == 0 due to bug in old numpy
  2683. for M in range(1, 7):
  2684. X = rng.random((100, M, M))
  2685. w, v = np.linalg.eig(X)
  2686. wr, vr = cdf2rdf(w, v)
  2687. self.assert_eig_valid(wr, vr, X)
  2688. def test_random_2d_stacked_arrays(self):
  2689. rng = np.random.default_rng(1234)
  2690. # cannot test M == 0 due to bug in old numpy
  2691. for M in range(1, 7):
  2692. X = rng.random((10, 10, M, M))
  2693. w, v = np.linalg.eig(X)
  2694. wr, vr = cdf2rdf(w, v)
  2695. self.assert_eig_valid(wr, vr, X)
  2696. def test_low_dimensionality_error(self):
  2697. w, v = np.empty(()), np.array((2,))
  2698. assert_raises(ValueError, cdf2rdf, w, v)
  2699. def test_not_square_error(self):
  2700. # Check that passing a non-square array raises a ValueError.
  2701. w, v = np.arange(3), np.arange(6).reshape(3, 2)
  2702. assert_raises(ValueError, cdf2rdf, w, v)
  2703. def test_swapped_v_w_error(self):
  2704. # Check that exchanging places of w and v raises ValueError.
  2705. X = np.array([[1, 2, 3], [0, 4, 5], [0, -5, 4]])
  2706. w, v = np.linalg.eig(X)
  2707. assert_raises(ValueError, cdf2rdf, v, w)
  2708. def test_non_associated_error(self):
  2709. # Check that passing non-associated eigenvectors raises a ValueError.
  2710. w, v = np.arange(3), np.arange(16).reshape(4, 4)
  2711. assert_raises(ValueError, cdf2rdf, w, v)
  2712. def test_not_conjugate_pairs(self):
  2713. # Check that passing non-conjugate pairs raises a ValueError.
  2714. X = np.array([[1, 2, 3], [1, 2, 3], [2, 5, 6+1j]])
  2715. w, v = np.linalg.eig(X)
  2716. assert_raises(ValueError, cdf2rdf, w, v)
  2717. # different arrays in the stack, so not conjugate
  2718. X = np.array([
  2719. [[1, 2, 3], [1, 2, 3], [2, 5, 6+1j]],
  2720. [[1, 2, 3], [1, 2, 3], [2, 5, 6-1j]],
  2721. ])
  2722. w, v = np.linalg.eig(X)
  2723. assert_raises(ValueError, cdf2rdf, w, v)