test_basic.py 103 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767
  1. import os
  2. import platform
  3. import itertools
  4. import warnings
  5. import numpy as np
  6. from numpy import (arange, array, dot, zeros, identity, conjugate, transpose,
  7. float32)
  8. from numpy.testing import (assert_equal, assert_almost_equal, assert_,
  9. assert_array_almost_equal, assert_allclose,
  10. assert_array_equal)
  11. import pytest
  12. from pytest import raises as assert_raises
  13. from scipy.linalg import (solve, inv, det, lstsq, pinv, pinvh, norm,
  14. solve_banded, solveh_banded, solve_triangular,
  15. solve_circulant, circulant, LinAlgError, block_diag,
  16. matrix_balance, qr, LinAlgWarning)
  17. from scipy.linalg._testutils import assert_no_overwrite
  18. from scipy._lib._testutils import check_free_memory, IS_MUSL
  19. from scipy.linalg.blas import HAS_ILP64
  20. from scipy.conftest import skip_xp_invalid_arg
  21. REAL_DTYPES = (np.float32, np.float64, np.longdouble)
  22. COMPLEX_DTYPES = (np.complex64, np.complex128, np.clongdouble)
  23. DTYPES = REAL_DTYPES + COMPLEX_DTYPES
  24. parametrize_overwrite_arg = pytest.mark.parametrize(
  25. "overwrite_kw", [{"overwrite_a": True}, {"overwrite_a": False}, {}],
  26. ids=["True", "False", "None"]
  27. )
  28. parametrize_overwrite_b_arg = pytest.mark.parametrize(
  29. "overwrite_b_kw", [{"overwrite_b": True}, {"overwrite_b": False}, {}],
  30. ids=["True", "False", "None"]
  31. )
  32. def _eps_cast(dtyp):
  33. """Get the epsilon for dtype, possibly downcast to BLAS types."""
  34. dt = dtyp
  35. if dt == np.longdouble:
  36. dt = np.float64
  37. elif dt == np.clongdouble:
  38. dt = np.complex128
  39. return np.finfo(dt).eps
  40. class TestSolveBanded:
  41. def test_real(self):
  42. a = array([[1.0, 20, 0, 0],
  43. [-30, 4, 6, 0],
  44. [2, 1, 20, 2],
  45. [0, -1, 7, 14]])
  46. ab = array([[0.0, 20, 6, 2],
  47. [1, 4, 20, 14],
  48. [-30, 1, 7, 0],
  49. [2, -1, 0, 0]])
  50. l, u = 2, 1
  51. b4 = array([10.0, 0.0, 2.0, 14.0])
  52. b4by1 = b4.reshape(-1, 1)
  53. b4by2 = array([[2, 1],
  54. [-30, 4],
  55. [2, 3],
  56. [1, 3]])
  57. b4by4 = array([[1, 0, 0, 0],
  58. [0, 0, 0, 1],
  59. [0, 1, 0, 0],
  60. [0, 1, 0, 0]])
  61. for b in [b4, b4by1, b4by2, b4by4]:
  62. x = solve_banded((l, u), ab, b)
  63. assert_array_almost_equal(dot(a, x), b)
  64. def test_complex(self):
  65. a = array([[1.0, 20, 0, 0],
  66. [-30, 4, 6, 0],
  67. [2j, 1, 20, 2j],
  68. [0, -1, 7, 14]])
  69. ab = array([[0.0, 20, 6, 2j],
  70. [1, 4, 20, 14],
  71. [-30, 1, 7, 0],
  72. [2j, -1, 0, 0]])
  73. l, u = 2, 1
  74. b4 = array([10.0, 0.0, 2.0, 14.0j])
  75. b4by1 = b4.reshape(-1, 1)
  76. b4by2 = array([[2, 1],
  77. [-30, 4],
  78. [2, 3],
  79. [1, 3]])
  80. b4by4 = array([[1, 0, 0, 0],
  81. [0, 0, 0, 1j],
  82. [0, 1, 0, 0],
  83. [0, 1, 0, 0]])
  84. for b in [b4, b4by1, b4by2, b4by4]:
  85. x = solve_banded((l, u), ab, b)
  86. assert_array_almost_equal(dot(a, x), b)
  87. def test_tridiag_real(self):
  88. ab = array([[0.0, 20, 6, 2],
  89. [1, 4, 20, 14],
  90. [-30, 1, 7, 0]])
  91. a = np.diag(ab[0, 1:], 1) + np.diag(ab[1, :], 0) + np.diag(
  92. ab[2, :-1], -1)
  93. b4 = array([10.0, 0.0, 2.0, 14.0])
  94. b4by1 = b4.reshape(-1, 1)
  95. b4by2 = array([[2, 1],
  96. [-30, 4],
  97. [2, 3],
  98. [1, 3]])
  99. b4by4 = array([[1, 0, 0, 0],
  100. [0, 0, 0, 1],
  101. [0, 1, 0, 0],
  102. [0, 1, 0, 0]])
  103. for b in [b4, b4by1, b4by2, b4by4]:
  104. x = solve_banded((1, 1), ab, b)
  105. assert_array_almost_equal(dot(a, x), b)
  106. def test_tridiag_complex(self):
  107. ab = array([[0.0, 20, 6, 2j],
  108. [1, 4, 20, 14],
  109. [-30, 1, 7, 0]])
  110. a = np.diag(ab[0, 1:], 1) + np.diag(ab[1, :], 0) + np.diag(
  111. ab[2, :-1], -1)
  112. b4 = array([10.0, 0.0, 2.0, 14.0j])
  113. b4by1 = b4.reshape(-1, 1)
  114. b4by2 = array([[2, 1],
  115. [-30, 4],
  116. [2, 3],
  117. [1, 3]])
  118. b4by4 = array([[1, 0, 0, 0],
  119. [0, 0, 0, 1],
  120. [0, 1, 0, 0],
  121. [0, 1, 0, 0]])
  122. for b in [b4, b4by1, b4by2, b4by4]:
  123. x = solve_banded((1, 1), ab, b)
  124. assert_array_almost_equal(dot(a, x), b)
  125. def test_check_finite(self):
  126. a = array([[1.0, 20, 0, 0],
  127. [-30, 4, 6, 0],
  128. [2, 1, 20, 2],
  129. [0, -1, 7, 14]])
  130. ab = array([[0.0, 20, 6, 2],
  131. [1, 4, 20, 14],
  132. [-30, 1, 7, 0],
  133. [2, -1, 0, 0]])
  134. l, u = 2, 1
  135. b4 = array([10.0, 0.0, 2.0, 14.0])
  136. x = solve_banded((l, u), ab, b4, check_finite=False)
  137. assert_array_almost_equal(dot(a, x), b4)
  138. def test_bad_shape(self):
  139. ab = array([[0.0, 20, 6, 2],
  140. [1, 4, 20, 14],
  141. [-30, 1, 7, 0],
  142. [2, -1, 0, 0]])
  143. l, u = 2, 1
  144. bad = array([1.0, 2.0, 3.0, 4.0]).reshape(-1, 4)
  145. assert_raises(ValueError, solve_banded, (l, u), ab, bad)
  146. assert_raises(ValueError, solve_banded, (l, u), ab, [1.0, 2.0])
  147. # Values of (l,u) are not compatible with ab.
  148. assert_raises(ValueError, solve_banded, (1, 1), ab, [1.0, 2.0])
  149. def test_1x1(self):
  150. # gh-8906 noted that the case of A@x = b with 1x1 A was handled
  151. # incorrectly; check that this is resolved. Typical case:
  152. # nupper == nlower == 0
  153. # A = [[2]]
  154. b = array([[1., 2., 3.]])
  155. ref = array([[0.5, 1.0, 1.5]])
  156. x = solve_banded((0, 0), [[2]], b)
  157. assert_allclose(x, ref, rtol=1e-15)
  158. # However, the user *can* represent the same system with garbage rows
  159. # in `ab`. Test the case with `nupper == 1, nlower == 1`.
  160. x = solve_banded((1, 1), [[0], [2], [0]], b)
  161. assert_allclose(x, ref, rtol=1e-15)
  162. assert_equal(x.dtype, np.dtype('f8'))
  163. assert_array_equal(b, [[1.0, 2.0, 3.0]])
  164. def test_native_list_arguments(self):
  165. a = [[1.0, 20, 0, 0],
  166. [-30, 4, 6, 0],
  167. [2, 1, 20, 2],
  168. [0, -1, 7, 14]]
  169. ab = [[0.0, 20, 6, 2],
  170. [1, 4, 20, 14],
  171. [-30, 1, 7, 0],
  172. [2, -1, 0, 0]]
  173. l, u = 2, 1
  174. b = [10.0, 0.0, 2.0, 14.0]
  175. x = solve_banded((l, u), ab, b)
  176. assert_array_almost_equal(dot(a, x), b)
  177. @pytest.mark.parametrize('dt_ab', [int, float, np.float32, complex, np.complex64])
  178. @pytest.mark.parametrize('dt_b', [int, float, np.float32, complex, np.complex64])
  179. def test_empty(self, dt_ab, dt_b):
  180. # ab contains one empty row corresponding to the diagonal
  181. ab = np.array([[]], dtype=dt_ab)
  182. b = np.array([], dtype=dt_b)
  183. x = solve_banded((0, 0), ab, b)
  184. assert x.shape == (0,)
  185. assert x.dtype == solve(np.eye(1, dtype=dt_ab), np.ones(1, dtype=dt_b)).dtype
  186. b = np.empty((0, 0), dtype=dt_b)
  187. x = solve_banded((0, 0), ab, b)
  188. assert x.shape == (0, 0)
  189. assert x.dtype == solve(np.eye(1, dtype=dt_ab), np.ones(1, dtype=dt_b)).dtype
  190. class TestSolveHBanded:
  191. def test_01_upper(self):
  192. # Solve
  193. # [ 4 1 2 0] [1]
  194. # [ 1 4 1 2] X = [4]
  195. # [ 2 1 4 1] [1]
  196. # [ 0 2 1 4] [2]
  197. # with the RHS as a 1D array.
  198. ab = array([[0.0, 0.0, 2.0, 2.0],
  199. [-99, 1.0, 1.0, 1.0],
  200. [4.0, 4.0, 4.0, 4.0]])
  201. b = array([1.0, 4.0, 1.0, 2.0])
  202. x = solveh_banded(ab, b)
  203. assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0])
  204. def test_02_upper(self):
  205. # Solve
  206. # [ 4 1 2 0] [1 6]
  207. # [ 1 4 1 2] X = [4 2]
  208. # [ 2 1 4 1] [1 6]
  209. # [ 0 2 1 4] [2 1]
  210. #
  211. ab = array([[0.0, 0.0, 2.0, 2.0],
  212. [-99, 1.0, 1.0, 1.0],
  213. [4.0, 4.0, 4.0, 4.0]])
  214. b = array([[1.0, 6.0],
  215. [4.0, 2.0],
  216. [1.0, 6.0],
  217. [2.0, 1.0]])
  218. x = solveh_banded(ab, b)
  219. expected = array([[0.0, 1.0],
  220. [1.0, 0.0],
  221. [0.0, 1.0],
  222. [0.0, 0.0]])
  223. assert_array_almost_equal(x, expected)
  224. def test_03_upper(self):
  225. # Solve
  226. # [ 4 1 2 0] [1]
  227. # [ 1 4 1 2] X = [4]
  228. # [ 2 1 4 1] [1]
  229. # [ 0 2 1 4] [2]
  230. # with the RHS as a 2D array with shape (3,1).
  231. ab = array([[0.0, 0.0, 2.0, 2.0],
  232. [-99, 1.0, 1.0, 1.0],
  233. [4.0, 4.0, 4.0, 4.0]])
  234. b = array([1.0, 4.0, 1.0, 2.0]).reshape(-1, 1)
  235. x = solveh_banded(ab, b)
  236. assert_array_almost_equal(x, array([0., 1., 0., 0.]).reshape(-1, 1))
  237. def test_01_lower(self):
  238. # Solve
  239. # [ 4 1 2 0] [1]
  240. # [ 1 4 1 2] X = [4]
  241. # [ 2 1 4 1] [1]
  242. # [ 0 2 1 4] [2]
  243. #
  244. ab = array([[4.0, 4.0, 4.0, 4.0],
  245. [1.0, 1.0, 1.0, -99],
  246. [2.0, 2.0, 0.0, 0.0]])
  247. b = array([1.0, 4.0, 1.0, 2.0])
  248. x = solveh_banded(ab, b, lower=True)
  249. assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0])
  250. def test_02_lower(self):
  251. # Solve
  252. # [ 4 1 2 0] [1 6]
  253. # [ 1 4 1 2] X = [4 2]
  254. # [ 2 1 4 1] [1 6]
  255. # [ 0 2 1 4] [2 1]
  256. #
  257. ab = array([[4.0, 4.0, 4.0, 4.0],
  258. [1.0, 1.0, 1.0, -99],
  259. [2.0, 2.0, 0.0, 0.0]])
  260. b = array([[1.0, 6.0],
  261. [4.0, 2.0],
  262. [1.0, 6.0],
  263. [2.0, 1.0]])
  264. x = solveh_banded(ab, b, lower=True)
  265. expected = array([[0.0, 1.0],
  266. [1.0, 0.0],
  267. [0.0, 1.0],
  268. [0.0, 0.0]])
  269. assert_array_almost_equal(x, expected)
  270. def test_01_float32(self):
  271. # Solve
  272. # [ 4 1 2 0] [1]
  273. # [ 1 4 1 2] X = [4]
  274. # [ 2 1 4 1] [1]
  275. # [ 0 2 1 4] [2]
  276. #
  277. ab = array([[0.0, 0.0, 2.0, 2.0],
  278. [-99, 1.0, 1.0, 1.0],
  279. [4.0, 4.0, 4.0, 4.0]], dtype=float32)
  280. b = array([1.0, 4.0, 1.0, 2.0], dtype=float32)
  281. x = solveh_banded(ab, b)
  282. assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0])
  283. def test_02_float32(self):
  284. # Solve
  285. # [ 4 1 2 0] [1 6]
  286. # [ 1 4 1 2] X = [4 2]
  287. # [ 2 1 4 1] [1 6]
  288. # [ 0 2 1 4] [2 1]
  289. #
  290. ab = array([[0.0, 0.0, 2.0, 2.0],
  291. [-99, 1.0, 1.0, 1.0],
  292. [4.0, 4.0, 4.0, 4.0]], dtype=float32)
  293. b = array([[1.0, 6.0],
  294. [4.0, 2.0],
  295. [1.0, 6.0],
  296. [2.0, 1.0]], dtype=float32)
  297. x = solveh_banded(ab, b)
  298. expected = array([[0.0, 1.0],
  299. [1.0, 0.0],
  300. [0.0, 1.0],
  301. [0.0, 0.0]])
  302. assert_array_almost_equal(x, expected)
  303. def test_01_complex(self):
  304. # Solve
  305. # [ 4 -j 2 0] [2-j]
  306. # [ j 4 -j 2] X = [4-j]
  307. # [ 2 j 4 -j] [4+j]
  308. # [ 0 2 j 4] [2+j]
  309. #
  310. ab = array([[0.0, 0.0, 2.0, 2.0],
  311. [-99, -1.0j, -1.0j, -1.0j],
  312. [4.0, 4.0, 4.0, 4.0]])
  313. b = array([2-1.0j, 4.0-1j, 4+1j, 2+1j])
  314. x = solveh_banded(ab, b)
  315. assert_array_almost_equal(x, [0.0, 1.0, 1.0, 0.0])
  316. def test_02_complex(self):
  317. # Solve
  318. # [ 4 -j 2 0] [2-j 2+4j]
  319. # [ j 4 -j 2] X = [4-j -1-j]
  320. # [ 2 j 4 -j] [4+j 4+2j]
  321. # [ 0 2 j 4] [2+j j]
  322. #
  323. ab = array([[0.0, 0.0, 2.0, 2.0],
  324. [-99, -1.0j, -1.0j, -1.0j],
  325. [4.0, 4.0, 4.0, 4.0]])
  326. b = array([[2-1j, 2+4j],
  327. [4.0-1j, -1-1j],
  328. [4.0+1j, 4+2j],
  329. [2+1j, 1j]])
  330. x = solveh_banded(ab, b)
  331. expected = array([[0.0, 1.0j],
  332. [1.0, 0.0],
  333. [1.0, 1.0],
  334. [0.0, 0.0]])
  335. assert_array_almost_equal(x, expected)
  336. def test_tridiag_01_upper(self):
  337. # Solve
  338. # [ 4 1 0] [1]
  339. # [ 1 4 1] X = [4]
  340. # [ 0 1 4] [1]
  341. # with the RHS as a 1D array.
  342. ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]])
  343. b = array([1.0, 4.0, 1.0])
  344. x = solveh_banded(ab, b)
  345. assert_array_almost_equal(x, [0.0, 1.0, 0.0])
  346. def test_tridiag_02_upper(self):
  347. # Solve
  348. # [ 4 1 0] [1 4]
  349. # [ 1 4 1] X = [4 2]
  350. # [ 0 1 4] [1 4]
  351. #
  352. ab = array([[-99, 1.0, 1.0],
  353. [4.0, 4.0, 4.0]])
  354. b = array([[1.0, 4.0],
  355. [4.0, 2.0],
  356. [1.0, 4.0]])
  357. x = solveh_banded(ab, b)
  358. expected = array([[0.0, 1.0],
  359. [1.0, 0.0],
  360. [0.0, 1.0]])
  361. assert_array_almost_equal(x, expected)
  362. def test_tridiag_03_upper(self):
  363. # Solve
  364. # [ 4 1 0] [1]
  365. # [ 1 4 1] X = [4]
  366. # [ 0 1 4] [1]
  367. # with the RHS as a 2D array with shape (3,1).
  368. ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]])
  369. b = array([1.0, 4.0, 1.0]).reshape(-1, 1)
  370. x = solveh_banded(ab, b)
  371. assert_array_almost_equal(x, array([0.0, 1.0, 0.0]).reshape(-1, 1))
  372. def test_tridiag_01_lower(self):
  373. # Solve
  374. # [ 4 1 0] [1]
  375. # [ 1 4 1] X = [4]
  376. # [ 0 1 4] [1]
  377. #
  378. ab = array([[4.0, 4.0, 4.0],
  379. [1.0, 1.0, -99]])
  380. b = array([1.0, 4.0, 1.0])
  381. x = solveh_banded(ab, b, lower=True)
  382. assert_array_almost_equal(x, [0.0, 1.0, 0.0])
  383. def test_tridiag_02_lower(self):
  384. # Solve
  385. # [ 4 1 0] [1 4]
  386. # [ 1 4 1] X = [4 2]
  387. # [ 0 1 4] [1 4]
  388. #
  389. ab = array([[4.0, 4.0, 4.0],
  390. [1.0, 1.0, -99]])
  391. b = array([[1.0, 4.0],
  392. [4.0, 2.0],
  393. [1.0, 4.0]])
  394. x = solveh_banded(ab, b, lower=True)
  395. expected = array([[0.0, 1.0],
  396. [1.0, 0.0],
  397. [0.0, 1.0]])
  398. assert_array_almost_equal(x, expected)
  399. def test_tridiag_01_float32(self):
  400. # Solve
  401. # [ 4 1 0] [1]
  402. # [ 1 4 1] X = [4]
  403. # [ 0 1 4] [1]
  404. #
  405. ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]], dtype=float32)
  406. b = array([1.0, 4.0, 1.0], dtype=float32)
  407. x = solveh_banded(ab, b)
  408. assert_array_almost_equal(x, [0.0, 1.0, 0.0])
  409. def test_tridiag_02_float32(self):
  410. # Solve
  411. # [ 4 1 0] [1 4]
  412. # [ 1 4 1] X = [4 2]
  413. # [ 0 1 4] [1 4]
  414. #
  415. ab = array([[-99, 1.0, 1.0],
  416. [4.0, 4.0, 4.0]], dtype=float32)
  417. b = array([[1.0, 4.0],
  418. [4.0, 2.0],
  419. [1.0, 4.0]], dtype=float32)
  420. x = solveh_banded(ab, b)
  421. expected = array([[0.0, 1.0],
  422. [1.0, 0.0],
  423. [0.0, 1.0]])
  424. assert_array_almost_equal(x, expected)
  425. def test_tridiag_01_complex(self):
  426. # Solve
  427. # [ 4 -j 0] [ -j]
  428. # [ j 4 -j] X = [4-j]
  429. # [ 0 j 4] [4+j]
  430. #
  431. ab = array([[-99, -1.0j, -1.0j], [4.0, 4.0, 4.0]])
  432. b = array([-1.0j, 4.0-1j, 4+1j])
  433. x = solveh_banded(ab, b)
  434. assert_array_almost_equal(x, [0.0, 1.0, 1.0])
  435. def test_tridiag_02_complex(self):
  436. # Solve
  437. # [ 4 -j 0] [ -j 4j]
  438. # [ j 4 -j] X = [4-j -1-j]
  439. # [ 0 j 4] [4+j 4 ]
  440. #
  441. ab = array([[-99, -1.0j, -1.0j],
  442. [4.0, 4.0, 4.0]])
  443. b = array([[-1j, 4.0j],
  444. [4.0-1j, -1.0-1j],
  445. [4.0+1j, 4.0]])
  446. x = solveh_banded(ab, b)
  447. expected = array([[0.0, 1.0j],
  448. [1.0, 0.0],
  449. [1.0, 1.0]])
  450. assert_array_almost_equal(x, expected)
  451. def test_check_finite(self):
  452. # Solve
  453. # [ 4 1 0] [1]
  454. # [ 1 4 1] X = [4]
  455. # [ 0 1 4] [1]
  456. # with the RHS as a 1D array.
  457. ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]])
  458. b = array([1.0, 4.0, 1.0])
  459. x = solveh_banded(ab, b, check_finite=False)
  460. assert_array_almost_equal(x, [0.0, 1.0, 0.0])
  461. def test_bad_shapes(self):
  462. ab = array([[-99, 1.0, 1.0],
  463. [4.0, 4.0, 4.0]])
  464. b = array([[1.0, 4.0],
  465. [4.0, 2.0]])
  466. assert_raises(ValueError, solveh_banded, ab, b)
  467. assert_raises(ValueError, solveh_banded, ab, [1.0, 2.0])
  468. assert_raises(ValueError, solveh_banded, ab, [1.0])
  469. def test_1x1(self):
  470. x = solveh_banded([[1]], [[1, 2, 3]])
  471. assert_array_equal(x, [[1.0, 2.0, 3.0]])
  472. assert_equal(x.dtype, np.dtype('f8'))
  473. def test_native_list_arguments(self):
  474. # Same as test_01_upper, using python's native list.
  475. ab = [[0.0, 0.0, 2.0, 2.0],
  476. [-99, 1.0, 1.0, 1.0],
  477. [4.0, 4.0, 4.0, 4.0]]
  478. b = [1.0, 4.0, 1.0, 2.0]
  479. x = solveh_banded(ab, b)
  480. assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0])
  481. @pytest.mark.parametrize('dt_ab', [int, float, np.float32, complex, np.complex64])
  482. @pytest.mark.parametrize('dt_b', [int, float, np.float32, complex, np.complex64])
  483. def test_empty(self, dt_ab, dt_b):
  484. # ab contains one empty row corresponding to the diagonal
  485. ab = np.array([[]], dtype=dt_ab)
  486. b = np.array([], dtype=dt_b)
  487. x = solveh_banded(ab, b)
  488. assert x.shape == (0,)
  489. assert x.dtype == solve(np.eye(1, dtype=dt_ab), np.ones(1, dtype=dt_b)).dtype
  490. b = np.empty((0, 0), dtype=dt_b)
  491. x = solveh_banded(ab, b)
  492. assert x.shape == (0, 0)
  493. assert x.dtype == solve(np.eye(1, dtype=dt_ab), np.ones(1, dtype=dt_b)).dtype
  494. class TestSolve:
  495. def test_20Feb04_bug(self):
  496. a = [[1, 1], [1.0, 0]] # ok
  497. x0 = solve(a, [1, 0j])
  498. assert_array_almost_equal(dot(a, x0), [1, 0])
  499. # gives failure with clapack.zgesv(..,rowmajor=0)
  500. a = [[1, 1], [1.2, 0]]
  501. b = [1, 0j]
  502. x0 = solve(a, b)
  503. assert_array_almost_equal(dot(a, x0), [1, 0])
  504. def test_simple(self):
  505. a = [[1, 20], [-30, 4]]
  506. for b in ([[1, 0], [0, 1]],
  507. [1, 0],
  508. [[2, 1], [-30, 4]]
  509. ):
  510. x = solve(a, b)
  511. assert_array_almost_equal(dot(a, x), b)
  512. def test_simple_complex(self):
  513. a = array([[5, 2], [2j, 4]], 'D')
  514. for b in ([1j, 0],
  515. [[1j, 1j], [0, 2]],
  516. [1, 0j],
  517. array([1, 0], 'D'),
  518. ):
  519. x = solve(a, b)
  520. assert_array_almost_equal(dot(a, x), b)
  521. def test_simple_pos(self):
  522. a = [[2, 3], [3, 5]]
  523. for lower in [0, 1]:
  524. for b in ([[1, 0], [0, 1]],
  525. [1, 0]
  526. ):
  527. x = solve(a, b, assume_a='pos', lower=lower)
  528. assert_array_almost_equal(dot(a, x), b)
  529. def test_simple_pos_complexb(self):
  530. a = [[5, 2], [2, 4]]
  531. for b in ([1j, 0],
  532. [[1j, 1j], [0, 2]],
  533. ):
  534. x = solve(a, b, assume_a='pos')
  535. assert_array_almost_equal(dot(a, x), b)
  536. def test_simple_sym(self):
  537. a = [[2, 3], [3, -5]]
  538. for lower in [0, 1]:
  539. for b in ([[1, 0], [0, 1]],
  540. [1, 0]
  541. ):
  542. x = solve(a, b, assume_a='sym', lower=lower)
  543. assert_array_almost_equal(dot(a, x), b)
  544. def test_simple_sym_complexb(self):
  545. a = [[5, 2], [2, -4]]
  546. for b in ([1j, 0],
  547. [[1j, 1j], [0, 2]]
  548. ):
  549. x = solve(a, b, assume_a='sym')
  550. assert_array_almost_equal(dot(a, x), b)
  551. def test_simple_sym_complex(self):
  552. a = [[5, 2+1j], [2+1j, -4]]
  553. for b in ([1j, 0],
  554. [1, 0],
  555. [[1j, 1j], [0, 2]]
  556. ):
  557. x = solve(a, b, assume_a='sym')
  558. assert_array_almost_equal(dot(a, x), b)
  559. def test_simple_her_actuallysym(self):
  560. a = [[2, 3], [3, -5]]
  561. for lower in [0, 1]:
  562. for b in ([[1, 0], [0, 1]],
  563. [1, 0],
  564. [1j, 0],
  565. ):
  566. x = solve(a, b, assume_a='her', lower=lower)
  567. assert_array_almost_equal(dot(a, x), b)
  568. def test_simple_her(self):
  569. a = [[5, 2+1j], [2-1j, -4]]
  570. for b in ([1j, 0],
  571. [1, 0],
  572. [[1j, 1j], [0, 2]]
  573. ):
  574. x = solve(a, b, assume_a='her')
  575. assert_array_almost_equal(dot(a, x), b)
  576. def test_nils_20Feb04(self):
  577. rng = np.random.default_rng(1234)
  578. n = 2
  579. A = rng.random([n, n])+rng.random([n, n])*1j
  580. X = zeros((n, n), 'D')
  581. Ainv = inv(A)
  582. R = identity(n)+identity(n)*0j
  583. for i in arange(0, n):
  584. r = R[:, i]
  585. X[:, i] = solve(A, r)
  586. assert_array_almost_equal(X, Ainv)
  587. def test_random(self):
  588. rng = np.random.default_rng(1234)
  589. n = 20
  590. a = rng.random([n, n])
  591. for i in range(n):
  592. a[i, i] = 20*(.1+a[i, i])
  593. for i in range(4):
  594. b = rng.random([n, 3])
  595. x = solve(a, b)
  596. assert_array_almost_equal(dot(a, x), b)
  597. def test_random_complex(self):
  598. rng = np.random.default_rng(1234)
  599. n = 20
  600. a = rng.random([n, n]) + 1j * rng.random([n, n])
  601. for i in range(n):
  602. a[i, i] = 20*(.1+a[i, i])
  603. for i in range(2):
  604. b = rng.random([n, 3])
  605. x = solve(a, b)
  606. assert_array_almost_equal(dot(a, x), b)
  607. def test_random_sym(self):
  608. rng = np.random.default_rng(1234)
  609. n = 20
  610. a = rng.random([n, n])
  611. for i in range(n):
  612. a[i, i] = abs(20*(.1+a[i, i]))
  613. for j in range(i):
  614. a[i, j] = a[j, i]
  615. for i in range(4):
  616. b = rng.random([n])
  617. x = solve(a, b, assume_a="pos")
  618. assert_array_almost_equal(dot(a, x), b)
  619. def test_random_sym_complex(self):
  620. rng = np.random.default_rng(1234)
  621. n = 20
  622. a = rng.random([n, n])
  623. a = a + 1j*rng.random([n, n])
  624. for i in range(n):
  625. a[i, i] = abs(20*(.1+a[i, i]))
  626. for j in range(i):
  627. a[i, j] = conjugate(a[j, i])
  628. b = rng.random([n])+2j*rng.random([n])
  629. for i in range(2):
  630. x = solve(a, b, assume_a="pos")
  631. assert_array_almost_equal(dot(a, x), b)
  632. def test_check_finite(self):
  633. a = [[1, 20], [-30, 4]]
  634. for b in ([[1, 0], [0, 1]], [1, 0],
  635. [[2, 1], [-30, 4]]):
  636. x = solve(a, b, check_finite=False)
  637. assert_array_almost_equal(dot(a, x), b)
  638. def test_scalar_a_and_1D_b(self):
  639. a = 1
  640. b = [1, 2, 3]
  641. x = solve(a, b)
  642. assert_array_almost_equal(x.ravel(), b)
  643. assert_(x.shape == (3,), 'Scalar_a_1D_b test returned wrong shape')
  644. def test_simple2(self):
  645. a = np.array([[1.80, 2.88, 2.05, -0.89],
  646. [525.00, -295.00, -95.00, -380.00],
  647. [1.58, -2.69, -2.90, -1.04],
  648. [-1.11, -0.66, -0.59, 0.80]])
  649. b = np.array([[9.52, 18.47],
  650. [2435.00, 225.00],
  651. [0.77, -13.28],
  652. [-6.22, -6.21]])
  653. x = solve(a, b)
  654. assert_array_almost_equal(x, np.array([[1., -1, 3, -5],
  655. [3, 2, 4, 1]]).T)
  656. def test_simple_complex2(self):
  657. a = np.array([[-1.34+2.55j, 0.28+3.17j, -6.39-2.20j, 0.72-0.92j],
  658. [-1.70-14.10j, 33.10-1.50j, -1.50+13.40j, 12.90+13.80j],
  659. [-3.29-2.39j, -1.91+4.42j, -0.14-1.35j, 1.72+1.35j],
  660. [2.41+0.39j, -0.56+1.47j, -0.83-0.69j, -1.96+0.67j]])
  661. b = np.array([[26.26+51.78j, 31.32-6.70j],
  662. [64.30-86.80j, 158.60-14.20j],
  663. [-5.75+25.31j, -2.15+30.19j],
  664. [1.16+2.57j, -2.56+7.55j]])
  665. x = solve(a, b)
  666. assert_array_almost_equal(x, np. array([[1+1.j, -1-2.j],
  667. [2-3.j, 5+1.j],
  668. [-4-5.j, -3+4.j],
  669. [6.j, 2-3.j]]))
  @pytest.mark.parametrize("assume_a", ['her', 'sym'])
  def test_symmetric_hermitian(self, assume_a):
      """'sym'/'her' solves on one-triangle data match a dense general solve."""
      # Only the upper triangle is populated.
      upper = np.array([[-1.84, 0.11-0.11j, -1.78-1.18j, 3.91-1.50j],
                        [0, -4.63, -1.84+0.03j, 2.21+0.21j],
                        [0, 0, -8.87, 1.58-0.90j],
                        [0, 0, 0, -1.36]])
      b = np.array([[2.98-10.18j, 28.68-39.89j],
                    [-9.58+3.88j, -24.79-8.40j],
                    [-0.77-16.05j, 4.23-70.02j],
                    [7.79+5.48j, -35.39+18.01j]])
      # The mirrored (lower-triangular) data, used for testing `lower`.
      if assume_a == 'sym':
          lower_data = upper.T
      else:
          lower_data = upper.conj().T
      # Dense matrix for the reference solution; reset the diagonal, which
      # the sum above counts twice.
      full = upper + lower_data
      idx = np.arange(4)
      full[idx, idx] = np.diag(upper)
      ref = solve(full, b, assume_a='general')
      assert_array_almost_equal(solve(upper, b, assume_a=assume_a), ref)
      # Lower-triangular data with lower=True. This also covers the gh-22265
      # resolution; otherwise, a warning would be emitted.
      assert_array_almost_equal(
          solve(lower_data, b, assume_a=assume_a, lower=True), ref
      )
  691. def test_pos_and_sym(self):
  692. A = np.arange(1, 10).reshape(3, 3)
  693. x = solve(np.tril(A)/9, np.ones(3), assume_a='pos')
  694. assert_array_almost_equal(x, [9., 1.8, 1.])
  695. x = solve(np.tril(A)/9, np.ones(3), assume_a='sym')
  696. assert_array_almost_equal(x, [9., 1.8, 1.])
  697. def test_singularity(self):
  698. a = np.array([[1, 0, 0, 0, 0, 0, 1, 0, 1],
  699. [1, 1, 1, 0, 0, 0, 1, 0, 1],
  700. [0, 1, 1, 0, 0, 0, 1, 0, 1],
  701. [1, 0, 1, 1, 1, 1, 0, 0, 0],
  702. [1, 0, 1, 1, 1, 1, 0, 0, 0],
  703. [1, 0, 1, 1, 1, 1, 0, 0, 0],
  704. [1, 0, 1, 1, 1, 1, 0, 0, 0],
  705. [1, 1, 1, 1, 1, 1, 1, 1, 1],
  706. [1, 1, 1, 1, 1, 1, 1, 1, 1]])
  707. b = np.arange(9)[:, None]
  708. assert_raises(LinAlgError, solve, a, b)
  @pytest.mark.parametrize('structure',
                           ('diagonal', 'tridiagonal', 'lower triangular',
                            'upper triangular', 'symmetric', 'hermitian',
                            'positive definite', 'general', 'banded', None))
  def test_ill_condition_warning(self, structure):
      """A diagonal matrix spanning 50 orders of magnitude emits LinAlgWarning."""
      rng = np.random.default_rng(234859349452)
      n = 10
      A = np.diag(np.logspace(0, 50, n))
      b = rng.random(size=n)
      pattern = "(Ill-conditioned matrix|An ill-conditioned matrix)"
      with pytest.warns(LinAlgWarning, match=pattern):
          solve(A, b, assume_a=structure)
  @pytest.mark.parametrize('structure',
                           ('diagonal', 'tridiagonal', 'lower triangular',
                            'upper triangular', 'symmetric', 'hermitian',
                            'positive definite', 'general', None))
  def test_exactly_singular_gh22263(self, structure):
      """An all-zero matrix raises a "singular" LinAlgError for every structure."""
      n = 10
      zero_matrix = np.zeros((n, n))
      rhs = np.ones(n)
      # floating-point warnings from the degenerate solve are silenced
      with np.errstate(all='ignore'), pytest.raises(LinAlgError, match="singular"):
          solve(zero_matrix, rhs, assume_a=structure)
  @pytest.mark.parametrize('b', [0, 1, [0, 1]])
  def test_singular_scalar(self, b):
      # regression test for gh-24355: scalar a=0 is singular
      # thus should raise the same error, whichever way the zero is spelled
      for zero_a in (np.zeros((1, 1)), 0, [[0]]):
          with pytest.raises(LinAlgError):
              solve(zero_a, b)
  743. def test_multiple_rhs(self):
  744. a = np.eye(2)
  745. rng = np.random.default_rng(1234)
  746. b = rng.random((2, 12))
  747. x = solve(a, b)
  748. assert_array_almost_equal(x, b)
  749. def test_transposed_keyword(self):
  750. A = np.arange(9).reshape(3, 3) + 1
  751. x = solve(np.tril(A)/9, np.ones(3), transposed=True)
  752. assert_array_almost_equal(x, [1.2, 0.2, 1])
  753. x = solve(np.tril(A)/9, np.ones(3), transposed=False)
  754. assert_array_almost_equal(x, [9, -5.4, -1.2])
  @pytest.mark.skip(reason="1. why? 2. deprecate the kwarg altogether?")
  def test_transposed_notimplemented(self):
      """(Skipped) complex `a` with transposed=True expected NotImplementedError."""
      complex_eye = np.eye(3).astype(complex)
      with assert_raises(NotImplementedError):
          solve(complex_eye, complex_eye, transposed=True)
  760. def test_nonsquare_a(self):
  761. assert_raises(ValueError, solve, [1, 2], 1)
  762. def test_size_mismatch_with_1D_b(self):
  763. assert_array_almost_equal(solve(np.eye(3), np.ones(3)), np.ones(3))
  764. assert_raises(ValueError, solve, np.eye(3), np.ones(4))
  765. def test_assume_a_keyword(self):
  766. assert_raises(ValueError, solve, 1, 1, assume_a='zxcv')
  @pytest.mark.parametrize("size", [10, 100])
  @pytest.mark.parametrize("assume_a", ['gen', 'sym', 'pos', 'her', 'tridiagonal'])
  @pytest.mark.parametrize(
      "dtype", [np.float32, np.float64, np.complex64, np.complex128]
  )
  def test_all_type_size_routine_combinations(self, size, dtype, assume_a):
      """Residual check of `solve` for each dtype / structure / size combination."""
      rng = np.random.default_rng(1234)
      is_complex = dtype in (np.complex64, np.complex128)
      a = rng.standard_normal((size, size)).astype(dtype)
      b = rng.standard_normal(size).astype(dtype)
      if is_complex:
          a += (1j*rng.standard_normal((size, size))).astype(dtype)

      # Impose the structure named by `assume_a` ('gen' keeps `a` as is).
      if assume_a == 'sym':
          # symmetric, and still complex when the dtype is complex
          a = a + a.T
      elif assume_a == 'her':
          a = a + a.T.conj()
      elif assume_a == 'pos':
          # NOTE(review): `np.eye(size)` is float64, so for float32/complex64
          # this promotes `a` to double precision. Confirm whether the
          # single-precision 'pos' path is meant to be exercised here.
          a = a.T.conj() @ a + 0.1*np.eye(size)
      elif assume_a == 'tridiagonal':
          # keep the main diagonal and the first super-/sub-diagonals only
          a = np.triu(np.tril(a, 1), -1)

      single = dtype in (np.float32, np.complex64)
      tol = 1e-6 if single else 1e-12
      if single and assume_a in ['gen', 'sym', 'her']:
          # We revert the tolerance from before
          # 4b4a6e7c34fa4060533db38f9a819b98fa81476c
          tol *= 10
      bound = tol * size

      x = solve(a, b, assume_a=assume_a)
      assert_allclose(a @ x, b, atol=bound, rtol=bound)
      if assume_a == 'sym' and not is_complex:
          x = solve(a, b, assume_a=assume_a, transposed=True)
          assert_allclose(a @ x, b, atol=bound, rtol=bound)
  @pytest.mark.parametrize('dt_a', [int, float, np.float32, complex, np.complex64])
  @pytest.mark.parametrize('dt_b', [int, float, np.float32, complex, np.complex64])
  def test_empty(self, dt_a, dt_b):
      """Empty problems give empty results with the dtype of a non-empty solve."""
      a_empty = np.empty((0, 0), dtype=dt_a)
      b_empty = np.empty(0, dtype=dt_b)
      x = solve(a_empty, b_empty)
      assert x.size == 0
      # dtype that a non-empty problem with the same input dtypes produces
      expected_dtype = solve(np.eye(2, dtype=dt_a), np.ones(2, dtype=dt_b)).dtype
      assert x.dtype == expected_dtype
      assert x.shape == np.linalg.solve(a_empty, b_empty).shape

      # batch with a zero-length batch dimension
      x = solve(np.ones((3, 0, 2, 2), dtype=dt_a), np.ones((2, 4), dtype=dt_b))
      assert x.shape == (3, 0, 2, 4)
      assert x.dtype == expected_dtype
  816. def test_empty_rhs(self):
  817. a = np.eye(2)
  818. b = [[], []]
  819. x = solve(a, b)
  820. assert_(x.size == 0, 'Returned array is not empty')
  821. assert_(x.shape == (2, 0), 'Returned empty array shape is wrong')
  @pytest.mark.parametrize('dtype', [np.float64, np.complex128])
  @pytest.mark.parametrize('assume_a', ['diagonal', 'tridiagonal', 'banded',
                                        'lower triangular', 'upper triangular',
                                        'pos', 'positive definite',
                                        'symmetric', 'hermitian',
                                        'general', 'sym', 'her', 'gen'])
  @pytest.mark.parametrize('nrhs', [(), (5,)])
  @pytest.mark.parametrize('transposed', [True, False])
  @pytest.mark.parametrize('overwrite', [True, False])
  @pytest.mark.parametrize('fortran', [True, False])
  def test_structure_detection(self, dtype, assume_a, nrhs, transposed,
                               overwrite, fortran):
      """Explicit `assume_a` and automatic structure detection agree.

      A random matrix with the structure named by `assume_a` is solved with
      that `assume_a` and compared against `np.linalg.solve`. Except for
      'banded', the result is also compared with the result of `solve`
      called without `assume_a`. When overwriting is disabled, the inputs
      must be left intact.
      """
      rng = np.random.default_rng(982345982439826)
      n = 5 if not assume_a == 'banded' else 20
      b = rng.random(size=(n,) + nrhs)
      A = rng.random(size=(n, n))
      if np.issubdtype(dtype, np.complexfloating):
          b = b + rng.random(size=(n,) + nrhs) * 1j
          A = A + rng.random(size=(n, n)) * 1j
      if assume_a == 'diagonal':
          A = np.diag(np.diag(A))
      elif assume_a == 'lower triangular':
          A = np.tril(A)
      elif assume_a == 'upper triangular':
          A = np.triu(A)
      elif assume_a == 'tridiagonal':
          A = (np.diag(np.diag(A))
               + np.diag(np.diag(A, -1), -1)
               + np.diag(np.diag(A, 1), 1))
      elif assume_a == 'banded':
          # one sub-diagonal and two super-diagonals
          A = np.triu(np.tril(A, 2), -1)
      elif assume_a in {'symmetric', 'sym'}:
          A = A + A.T
      elif assume_a in {'hermitian', 'her'}:
          A = A + A.conj().T
      elif assume_a in {'positive definite', 'pos'}:
          A = A @ A.T.conj()
      if fortran:
          A = np.asfortranarray(A)
      A_copy = A.copy(order='A')
      b_copy = b.copy()
      if np.issubdtype(dtype, np.complexfloating) and transposed:
          message = "scipy.linalg.solve can currently..."
          with pytest.raises(NotImplementedError, match=message):
              solve(A, b, overwrite_a=overwrite, overwrite_b=overwrite,
                    transposed=transposed)
          return
      res = solve(A, b, overwrite_a=overwrite, overwrite_b=overwrite,
                  transposed=transposed, assume_a=assume_a)
      # Check that this solution is *correct*
      ref = np.linalg.solve(A_copy.T if transposed else A_copy, b_copy)
      assert_allclose(res, ref)
      # Check that `solve` correctly identifies the structure and returns
      # *exactly* the same solution whether `assume_a` is specified or not
      if assume_a != 'banded':  # structure detection removed for banded
          assert_allclose(
              solve(A_copy, b_copy, transposed=transposed), res, atol=1e-15
          )
      # Check that overwrite was respected
      if not overwrite:
          assert_equal(A, A_copy)
          assert_equal(b, b_copy)
  @pytest.mark.skipif(
      # NumpyVersion compares versions numerically; a plain string comparison
      # would mis-order e.g. "10.0" < "2".
      np.lib.NumpyVersion(np.__version__) < '2.0.0',
      reason="solve chokes on b.ndim == 1 in numpy < 2"
  )
  @pytest.mark.parametrize(
      "assume_a",
      [
          None, "diagonal", "general", "upper triangular", "lower triangular", "pos",
      ]
  )
  def test_vs_np_solve(self, assume_a):
      """Batched `solve` matches `np.linalg.solve` for 1-D and 2-D `b`."""
      e = np.eye(2)
      # batch of shape (4, 3, 2) holding the 2x2 matrices k * I, k = 1..24
      a = np.arange(1, 4*3*2 + 1).reshape((4, 3, 2, 1, 1)) * e
      for b in (np.ones(2), np.ones((2, 1)), np.ones((2, 2)) * [1, 2]):
          assert_allclose(solve(a, b, assume_a=assume_a), np.linalg.solve(a, b))
  902. def test_pos_lower(self):
  903. # regression test for
  904. # https://github.com/scipy/scipy/pull/23071#issuecomment-3085826112
  905. rng = np.random.default_rng(0)
  906. a = rng.normal(size=(4, 4))
  907. a = np.tril(np.matmul(a, np.conj(a.T))) # lower triangle of hermitian array
  908. b = rng.normal(size=(4, 2))
  909. out = solve(a, b, assume_a='pos', lower=True)
  910. aa = a + a.T - np.diag(np.diag(a)) # the full hermitian array
  911. result_np = np.linalg.solve(aa, b)
  912. assert_allclose(out, result_np, atol=1e-15)
  913. # repeat with uplo='U'
  914. out = solve(a.T, b, assume_a='pos', lower=False)
  915. assert_allclose(out, result_np, atol=1e-15)
  916. def test_pos_fails_sym_complex(self):
  917. # regression test for the `solve` analog of gh-24359
  918. # the matrix is 1) symmetric not hermitian, and 2) not positive definite:
  919. a = np.asarray([[ 182.56985285-64.28859483j, -177.24879835+11.0780499j ],
  920. [-177.24879835+11.0780499j , 177.24879835-11.0780499j ]])
  921. b = np.eye(2)
  922. ainv = solve(a, b)
  923. assert_allclose(ainv @ a, np.eye(2), atol=1e-14)
  924. ainv_sym = solve(a, b, assume_a="sym")
  925. assert_allclose(ainv_sym, ainv, atol=1e-14)
  926. # Specifying assume_a="pos" disables the structure detection, and directly
  927. # calls LAPACK routines zportf and zpotri.
  928. # Since zportf(a) does not error out, neither does solve.
  929. ainv_chol = solve(a, b, assume_a="pos")
  930. assert not np.allclose(ainv, ainv_chol, atol=1e-14)
  931. # Setting assume_a="pos" with a non-pos def matrix returned nonsense.
  932. # This is at least consistent with inv.
  933. ainv_inv = inv(a, assume_a="pos")
  934. assert_allclose(ainv_chol, ainv_inv, atol=1e-14)
  935. def test_readonly(self):
  936. a = np.eye(3)
  937. a.flags.writeable = False
  938. b = np.ones(3)
  939. x = solve(a, b)
  940. assert_allclose(x, b, atol=1e-14)
  @parametrize_overwrite_arg
  def test_batch_negative_stride(self, overwrite_kw):
      """Batched `solve` when a batch axis of `a` (and later `b`) is reversed."""
      a = np.arange(3*8).reshape(2, 3, 2, 2)
      a = a[:, ::-1, :, :]
      b = np.ones(2)
      x = solve(a, b, **overwrite_kw)
      assert x.shape == a.shape[:-1]
      # Matrix-vector residual: `[..., 0]` drops the trailing unit axis so the
      # product lines up with the length-2 `b` instead of broadcasting to
      # (..., 2, 2).
      assert_allclose((a @ x[..., None])[..., 0] - b, 0, atol=1e-14)
      # use b with a negative stride now
      b = np.ones((2, 4))[:, ::-1]
      x = solve(a, b, **overwrite_kw)
      assert x.shape == a.shape[:-1] + (b.shape[-1],)
      assert_allclose(a @ x - b, 0, atol=1e-14)
  @parametrize_overwrite_arg
  def test_core_negative_stride(self, overwrite_kw):
      """Batched `solve` when the row axis of `a` (and later `b`) is reversed."""
      a = np.arange(3*8).reshape(2, 3, 2, 2)
      a = a[:, :, ::-1, :]
      b = np.ones(2)
      x = solve(a, b, **overwrite_kw)
      assert x.shape == a.shape[:-1]
      # Matrix-vector residual; `[..., 0]` keeps it aligned with the 1-D `b`.
      assert_allclose((a @ x[..., None])[..., 0] - b, 0, atol=1e-14)
      # use b with a negative stride now
      b = np.ones((2, 4))[::-1, :]
      x = solve(a, b, **overwrite_kw)
      assert x.shape == a.shape[:-1] + (b.shape[-1],)
      assert_allclose(a @ x - b, 0, atol=1e-14)
  @parametrize_overwrite_arg
  def test_core_non_contiguous(self, overwrite_kw):
      """Batched `solve` with every other column of `a` (non-contiguous core)."""
      a = np.arange(3*8*2).reshape(2, 3, 2, 4)
      a = a[..., ::2]
      b = np.ones(2)
      x = solve(a, b, **overwrite_kw)
      assert x.shape == a.shape[:-1]
      # Matrix-vector residual; `[..., 0]` keeps it aligned with the 1-D `b`.
      assert_allclose((a @ x[..., None])[..., 0] - b, 0, atol=1e-14)
      # use strided b now
      b = np.ones(4)[::2]
      x = solve(a, b, **overwrite_kw)
      assert x.shape == a.shape[:-1]
      assert_allclose((a @ x[..., None])[..., 0] - b, 0, atol=1e-14)
  @parametrize_overwrite_arg
  def test_batch_non_contiguous(self, overwrite_kw):
      """Batched `solve` with every other batch entry of `a` (strided batch)."""
      a = np.arange(3*8*2).reshape(2, 6, 2, 2)
      a = a[:, ::2, ...]
      b = np.ones(2)
      x = solve(a, b, **overwrite_kw)
      assert x.shape == a.shape[:-1]
      # Matrix-vector residual; `[..., 0]` keeps it aligned with the 1-D `b`.
      assert_allclose((a @ x[..., None])[..., 0] - b, 0, atol=1e-14)
      # use strided b now
      b = np.ones((2, 6))[:, ::2]
      x = solve(a, b, **overwrite_kw)
      assert x.shape == a.shape[:-1] + (b.shape[-1],)
      assert_allclose(a @ x - b, 0, atol=1e-14)
  @parametrize_overwrite_arg
  def test_batch_weird_strides(self, overwrite_kw):
      """Batched `solve` on a 5-D array whose axes have been permuted."""
      a = np.arange(3*8*2).reshape(2, 3, 2, 2, 2)
      a = a.transpose(1, 3, 4, 0, 2)
      b = np.ones(2)
      x = solve(a, b, **overwrite_kw)
      assert x.shape == a.shape[:-1]
      # Matrix-vector residual; `[..., 0]` keeps it aligned with the 1-D `b`.
      assert_allclose((a @ x[..., None])[..., 0] - b, 0, atol=1e-14)
  @parametrize_overwrite_arg
  @parametrize_overwrite_b_arg
  @pytest.mark.parametrize('a_dtype', [int, float])
  @pytest.mark.parametrize('a_order', ['C', 'F'])
  @pytest.mark.parametrize('b_dtype', [int, float])
  @pytest.mark.parametrize('b_order', ['C', 'F'])
  @pytest.mark.parametrize('b_ndim', [1, 2])  # XXX ndim > 2
  @pytest.mark.parametrize('transposed', [True, False])
  def test_overwrite_args(
      self, overwrite_kw, overwrite_b_kw, a_dtype, a_order,
      b_dtype, b_order, b_ndim, transposed
  ):
      """The solution is always correct; inputs change only when in-place is possible.

      The test expects `a` (resp. `b`) to be modified exactly when its
      overwrite flag is set, it is not an int array, and it is
      Fortran-contiguous. When `b` is modified, the returned `x` shares
      its memory.
      """
      n = 3
      a = (np.arange(1, n**2 + 1).reshape(n, n) + np.eye(n)).astype(
          a_dtype, order=a_order
      )
      column = np.arange(n)
      if b_ndim > 1:
          b = np.stack([column*j for j in range(b_ndim)]).T
      else:
          b = column
      b = b.astype(b_dtype, order=b_order)
      a_ref, b_ref = a.copy(), b.copy()

      x = solve(a, b, **overwrite_kw, **overwrite_b_kw, transposed=transposed)
      lhs_matrix = a_ref.T if transposed else a_ref
      assert_allclose(lhs_matrix @ x, b_ref, atol=1e-14)

      a_inplace = (overwrite_kw.get('overwrite_a', False)
                   and (a.dtype != int) and a.flags['F_CONTIGUOUS'])
      b_inplace = (overwrite_b_kw.get('overwrite_b', False)
                   and (b.dtype != int) and b.flags['F_CONTIGUOUS'])
      assert np.shares_memory(x, b) == b_inplace
      assert (b == b_ref).all() != b_inplace
      assert (a == a_ref).all() != a_inplace
  1034. def test_posdef_not_posdef(self):
  1035. # the `b` matrix is invertible but not positive definite
  1036. a = np.arange(9).reshape(3, 3)
  1037. A = a + a.T + np.eye(3)
  1038. b = np.ones(3)
  1039. # cholesky solver fails, and the routine falls back to the general inverse
  1040. x0 = solve(A, b)
  1041. assert_allclose(A @ x0, b, atol=1e-14)
  1042. # but it does not fall back if `assume_a` is given
  1043. with assert_raises(LinAlgError):
  1044. solve(A, b, assume_a='pos')
  def test_diagonal(self):
      """Diagonal structure: plain solve, ill-conditioning warning, singularity error."""
      # batch: slice 0 is an upper triangle of ones, slice 1 is diag(1, 2, 3)
      stacked = np.stack([np.triu(np.ones((3, 3))), np.diag(np.arange(1, 4))])
      x = solve(stacked, np.ones(3))
      assert_allclose(x[1, ...], 1 / np.arange(1, 4), atol=1e-14)

      rhs = np.ones(2)
      # ill-conditioned inputs warn
      with pytest.warns(LinAlgWarning):
          solve(np.asarray([[1e30, 0], [0, 1]]), rhs, assume_a="diagonal")
      # singular input (zero on the diagonal) raises
      with pytest.raises(LinAlgError):
          solve(np.asarray([[0, 0], [0, 1]]), rhs, assume_a="diagonal")
  def test_tridiagonal(self):
      """Tridiagonal structure: plain solve, ill-conditioning warning, singularity error."""
      n = 4
      off_diag = np.ones(n - 1)
      a = -2*np.diag(np.ones(n)) + np.diag(off_diag, 1) + np.diag(off_diag, -1)
      # batch: slice 0 is a dense upper triangle of ones, slice 1 is tridiagonal
      a = np.stack([np.triu(np.ones((n, n))), a])
      b = np.ones(n)
      x = solve(a, b)
      # basic tridiag solve (expected values are for n == 4)
      assert_allclose(x[1, ...], np.asarray([-2., -3., -3., -2.]), atol=1e-15)
      # ill-conditioned inputs warn
      a[1, 0, 0] = 1e20
      with pytest.warns(LinAlgWarning):
          solve(a, b, assume_a="tridiagonal")
      # singular inputs raise
      a[1, 0, 0] = a[1, 0, 1] = 0
      with pytest.raises(LinAlgError):
          solve(a, b, assume_a="tridiagonal")
  class TestSolveTriangular:
      """Tests for `solve_triangular`."""

      def test_simple(self):
          """
          solve_triangular on a simple real 2x2 lower-triangular matrix.
          """
          A = array([[1, 0], [1, 2]])
          rhs = [1, 1]
          assert_array_almost_equal(solve_triangular(A, rhs, lower=True), [1, 0])
          # A.T is upper triangular and is a transposed (non-C-contiguous) view
          expected_t = [.5, .5]
          sol = solve_triangular(A.T, rhs, lower=False)
          assert_array_almost_equal(sol, expected_t)
          # trans=1 on A must give the same result as solving with A.T
          sol = solve_triangular(A, rhs, lower=True, trans=1)
          assert_array_almost_equal(sol, expected_t)
          sol = solve_triangular(A, identity(2), lower=True, trans=1)
          assert_array_almost_equal(sol, [[1., -.5], [0, 0.5]])

      def test_simple_complex(self):
          """
          solve_triangular on a simple complex 2x2 matrix, for every trans mode
          """
          A = array([[1+1j, 0], [1j, 2]])
          sol = solve_triangular(A, identity(2), lower=True, trans=1)
          assert_array_almost_equal(sol, [[.5-.5j, -.25-.25j], [0, 0.5]])
          # other option combinations with a complex right-hand side:
          # (matrix, lower, trans, expected solution)
          b = np.diag([1+1j, 1+2j])
          cases = [
              (A, True, 0, [[1, 0], [-0.5j, 0.5+1j]]),
              (A, True, 1, [[1, 0.25-0.75j], [0, 0.5+1j]]),
              (A, True, 2, [[1j, -0.75-0.25j], [0, 0.5+1j]]),
              (A.T, False, 0, [[1, 0.25-0.75j], [0, 0.5+1j]]),
              (A.T, False, 1, [[1, 0], [-0.5j, 0.5+1j]]),
              (A.T, False, 2, [[1j, 0], [-0.5, 0.5+1j]]),
          ]
          for matrix, lower, trans, expected in cases:
              sol = solve_triangular(matrix, b, lower=lower, trans=trans)
              assert_array_almost_equal(sol, expected)

      def test_check_finite(self):
          """
          solve_triangular with check_finite=False on a simple 2x2 matrix.
          """
          A = array([[1, 0], [1, 2]])
          sol = solve_triangular(A, [1, 1], lower=True, check_finite=False)
          assert_array_almost_equal(sol, [1, 0])

      @pytest.mark.parametrize('dt_a', [int, float, np.float32, complex, np.complex64])
      @pytest.mark.parametrize('dt_b', [int, float, np.float32, complex, np.complex64])
      def test_empty(self, dt_a, dt_b):
          """Empty inputs give an empty result with the non-empty result dtype."""
          x = solve_triangular(np.empty((0, 0), dtype=dt_a),
                               np.empty(0, dtype=dt_b))
          assert x.size == 0
          expected_dtype = solve_triangular(
              np.eye(2, dtype=dt_a), np.ones(2, dtype=dt_b)
          ).dtype
          assert x.dtype == expected_dtype

      def test_empty_rhs(self):
          """A (2, 0) right-hand side yields a (2, 0) solution."""
          x = solve_triangular(np.eye(2), [[], []])
          assert_(x.size == 0, 'Returned array is not empty')
          assert_(x.shape == (2, 0), 'Returned empty array shape is wrong')
  1142. class TestInv:
  1143. def test_simple(self):
  1144. a = [[1, 2], [3, 4]]
  1145. a_inv = inv(a)
  1146. assert_array_almost_equal(dot(a, a_inv), np.eye(2))
  1147. a = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
  1148. a_inv = inv(a)
  1149. assert_array_almost_equal(dot(a, a_inv), np.eye(3))
  1150. def test_random(self):
  1151. rng = np.random.default_rng(1234)
  1152. n = 20
  1153. for i in range(4):
  1154. a = rng.random([n, n])
  1155. for i in range(n):
  1156. a[i, i] = 20*(.1+a[i, i])
  1157. a_inv = inv(a)
  1158. assert_array_almost_equal(dot(a, a_inv),
  1159. identity(n))
  1160. def test_simple_complex(self):
  1161. a = [[1, 2], [3, 4j]]
  1162. a_inv = inv(a)
  1163. assert_array_almost_equal(dot(a, a_inv), [[1, 0], [0, 1]])
  1164. def test_random_complex(self):
  1165. rng = np.random.default_rng(1234)
  1166. n = 20
  1167. for i in range(4):
  1168. a = rng.random([n, n])+2j*rng.random([n, n])
  1169. for i in range(n):
  1170. a[i, i] = 20*(.1+a[i, i])
  1171. a_inv = inv(a)
  1172. assert_array_almost_equal(dot(a, a_inv),
  1173. identity(n))
  1174. def test_check_finite(self):
  1175. a = [[1, 2], [3, 4]]
  1176. a_inv = inv(a, check_finite=False)
  1177. assert_array_almost_equal(dot(a, a_inv), [[1, 0], [0, 1]])
  @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  def test_empty(self, dt):
      """`inv` of empty arrays keeps shapes and the non-empty result dtype."""
      a_inv = inv(np.empty((0, 0), dtype=dt))
      assert a_inv.size == 0
      assert a_inv.dtype == inv(np.eye(2, dtype=dt)).dtype
      # zero-length batch dimension, then 0x0 core matrices in a batch
      for shape in ((3, 0, 2, 2), (3, 1, 0, 0)):
          assert inv(np.ones(shape, dtype=dt)).shape == shape
  @parametrize_overwrite_arg
  def test_overwrite_a(self, overwrite_kw):
      """`overwrite_a` is honoured only for 2-D F-ordered LAPACK-dtype input."""
      n = 3
      # NB: `np.eye` is float64, so `a0` is a float64 C-ordered array
      a0 = np.arange(1, n**2 + 1).reshape(n, n) + np.eye(n)
      # int arrays are copied internally; `a0` itself is float, so convert
      # explicitly (otherwise this case would duplicate the float case below)
      a = a0.astype(int)
      a_inv = inv(a, **overwrite_kw)
      assert_allclose(a_inv @ a, np.eye(n), atol=1e-14)
      assert_equal(a, a0)
      assert not np.shares_memory(a, a_inv)
      # float C ordered arrays are copied, too
      a = a0.copy().astype(float)
      a_inv = inv(a, **overwrite_kw)
      assert_allclose(a_inv @ a0, np.eye(n), atol=1e-14)
      assert_equal(a, a0)
      assert not np.shares_memory(a, a_inv)
      # 2D F-ordered arrays of LAPACK-compatible dtypes: inv works inplace.
      # IOW, the output is always the inverse, and the original input may be
      # destroyed, depending on the `overwrite_a` kwarg value
      a = a0.astype(float).copy(order='F')
      a_inv = inv(a, **overwrite_kw)
      assert_allclose(a_inv @ a0, np.eye(n), atol=1e-14)
      overwrite_a = overwrite_kw.get("overwrite_a", False)
      assert (a == a0).all() != overwrite_a
      assert np.shares_memory(a, a_inv) == overwrite_a
  @pytest.mark.parametrize(
      "dtyp", [np.float16, np.float32, np.longdouble, np.clongdouble]
  )
  def test_dtypes(self, dtyp):
      # backwards compat: inv(float16)->float32 ; inv(clongdouble)->complex128 etc
      expected_char = {
          'e': 'f',  # float16 -> float32
          'f': 'f',
          'g': 'd',  # longdouble -> float64
          'G': 'D',  # clongdouble -> complex128
      }
      a = np.arange(4).reshape(2, 2).astype(dtyp)
      a_inv = inv(a)
      tolerance = 100*np.finfo(a.dtype).eps
      assert_allclose(a @ a_inv, np.eye(a.shape[0]), atol=tolerance)
      assert a_inv.dtype.char == expected_char[a.dtype.char]
  1230. def test_readonly(self):
  1231. a = np.eye(3)
  1232. a.flags.writeable = False
  1233. a_inv = inv(a)
  1234. assert_allclose(a_inv, a, atol=1e-14)
  @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  def test_batch_core_1x1(self, dt):
      """A (3, 2) batch of 1x1 matrices [k], k = 1..6, is inverted elementwise."""
      a = (np.arange(3*2, dtype=dt) + 1).reshape(3, 2, 1, 1)
      a_inv = inv(a)
      assert a_inv.shape == a.shape
      assert_allclose(a @ a_inv, 1.)
  @parametrize_overwrite_arg
  def test_batch_zero_stride(self, overwrite_kw):
      """Batched `inv` when a length-1 axis is inserted into the batch shape."""
      a = np.arange(3*2*2, dtype=float).reshape(3, 2, 2)
      # new axis in front of, then inside, the existing batch axis
      for aa in (a[None, ...], a[:, None, ...]):
          a_inv = inv(aa, **overwrite_kw)
          assert a_inv.shape == aa.shape
          identities = np.broadcast_to(np.eye(2), aa.shape)
          assert_allclose(aa @ a_inv, identities, atol=2e-14)
  @parametrize_overwrite_arg
  def test_batch_negative_stride(self, overwrite_kw):
      """Batched `inv` when a batch axis is reversed (negative stride)."""
      a = np.arange(3*8).reshape(2, 3, 2, 2)[:, ::-1, :, :]
      a_inv = inv(a, **overwrite_kw)
      assert a_inv.shape == a.shape
      identities = np.broadcast_to(np.eye(2), a.shape)
      assert_allclose(a @ a_inv, identities, atol=5e-14)
  @parametrize_overwrite_arg
  def test_core_negative_stride(self, overwrite_kw):
      """Batched `inv` when the row axis of each matrix is reversed."""
      a = np.arange(3*8).reshape(2, 3, 2, 2)[:, :, ::-1, :]
      a_inv = inv(a, **overwrite_kw)
      assert a_inv.shape == a.shape
      identities = np.broadcast_to(np.eye(2), a.shape)
      assert_allclose(a @ a_inv, identities, atol=5e-14)
  @parametrize_overwrite_arg
  def test_core_non_contiguous(self, overwrite_kw):
      """Batched `inv` of 2x2 matrices taken from every other column."""
      a = np.arange(3*8*2).reshape(2, 3, 2, 4)[..., ::2]
      a_inv = inv(a, **overwrite_kw)
      assert a_inv.shape == (2, 3, 2, 2)
      identities = np.broadcast_to(np.eye(2), a.shape)
      assert_allclose(a @ a_inv, identities, atol=5e-14)
  @parametrize_overwrite_arg
  def test_batch_non_contiguous(self, overwrite_kw):
      """Batched `inv` over every other entry of a batch axis."""
      a = np.arange(3*8*2).reshape(2, 6, 2, 2)[:, ::2, ...]
      a_inv = inv(a, **overwrite_kw)
      assert a_inv.shape == (2, 3, 2, 2)
      identities = np.broadcast_to(np.eye(2), a.shape)
      assert_allclose(a @ a_inv, identities, atol=2e-13)
  @parametrize_overwrite_arg
  def test_singular(self, overwrite_kw):
      """Singular inputs raise LinAlgError, with and without `overwrite_a`."""
      # 2D case: A singular matrix: raise
      with assert_raises(LinAlgError):
          inv(np.ones((2, 2)), **overwrite_kw)
      # batched case: If all slices are singular, raise
      with assert_raises(LinAlgError):
          inv(np.ones((3, 2, 2)), **overwrite_kw)
      # XXX: shall we make this behavior configurable somehow?
      # A "keep-going" option would be this:
      # if some of the slices are singular and some are not,
      # - singular slices are filled with nans
      # - non-singular slices are inverted
      # - there is no error
      a = np.stack((np.ones((2, 2), dtype=complex), np.arange(4).reshape(2, 2)))
      with assert_raises(LinAlgError):
          inv(a, **overwrite_kw)
      # this would be true for a "keep-going" option
      # assert np.isnan(a_inv[0, ...]).all()
      # assert_allclose(a_inv[1, ...] @ a[1, ...], np.eye(2), atol=1e-14)
  def test_ill_cond(self):
      """Near-singular matrices (a 1e-20 diagonal entry) warn, including in a batch."""
      single = np.diag([1., 1e-20])
      batch = np.stack([np.diag([1., 1e-20]), np.diag([1, 1]), np.diag([1, 1e-20])])
      for a in (single, batch):
          with pytest.warns(LinAlgWarning):
              inv(a)
  1307. def test_wrong_assume_a(self):
  1308. with assert_raises(KeyError):
  1309. inv(np.eye(2), assume_a="kaboom")
  1310. def test_posdef(self):
  1311. x = np.arange(25, dtype=float).reshape(5, 5)
  1312. y = x + x.T
  1313. y += 21*np.eye(5)
  1314. y_inv0 = inv(y)
  1315. y_inv1 = inv(y, assume_a="pos")
  1316. assert_allclose(y_inv1, y_inv0, atol=1e-15)
  1317. # check that the lower triangle is not referenced for `lower=False`
  1318. mask = np.where(1 - np.tri(*y.shape, -1) == 0, np.nan, 1)
  1319. y_inv2 = inv(y*mask, check_finite=False, assume_a="pos", lower=False)
  1320. assert_allclose(y_inv2, y_inv0, atol=1e-15)
  1321. # repeat with the upper triangle
  1322. y_inv3 = inv(y*mask.T, check_finite=False, assume_a="pos", lower=True)
  1323. assert_allclose(y_inv3, y_inv0, atol=1e-15)
  @pytest.mark.parametrize('complex_', [False, True])
  def test_posdef_not_posdef(self, complex_):
      # the `b` matrix is invertible but not pos definite: test the "sym" fallback
      base = np.arange(9).reshape(3, 3)
      b = base + base.T + np.eye(3)
      if complex_:
          b = b + 1j*b
      # cholesky solver fails, and the routine falls back to the symmetric inverse
      assert_allclose(inv(b) @ b, np.eye(3), atol=3e-15)
      # but it does not fall back if `assume_a` is given
      with assert_raises(LinAlgError):
          inv(b, assume_a='pos')
      if not complex_:
          return
      # test posdef fallback to the hermitian solver, too
      c = base + 1j*base
      herm = c + c.T.conj() + np.eye(3)
      assert_allclose(inv(herm) @ herm, np.eye(3), atol=3e-15)
  1343. def test_pos_fails_sym_complex(self):
  1344. # regression test for gh-24359
  1345. # the matrix is 1) symmetric not hermitian, and 2) not positive definite:
  1346. a = np.asarray([[ 182.56985285-64.28859483j, -177.24879835+11.0780499j ],
  1347. [-177.24879835+11.0780499j , 177.24879835-11.0780499j ]])
  1348. ainv = inv(a)
  1349. assert_allclose(ainv @ a, np.eye(2), atol=1e-14)
  1350. ainv_sym = inv(a, assume_a="sym")
  1351. assert_allclose(ainv_sym, ainv, atol=1e-14)
  1352. # Specifying assume_a="pos" disables the structure detection, and directly
  1353. # calls LAPACK routines zportf and zpotri.
  1354. # Since zportf(a) does not error out, neither does inv
  1355. ainv_chol = inv(a, assume_a="pos")
  1356. assert not np.allclose(ainv, ainv_chol, atol=1e-14)
  1357. # Setting assume_a="pos" with a non-pos def matrix returned nonsense.
  1358. # This is at least consistent with solve.
  1359. ainv_slv = solve(a, np.eye(2), assume_a="pos")
  1360. assert_allclose(ainv_chol, ainv_slv, atol=1e-14)
  1361. # Repeat it for bunch of simple cases to cover more branches
  1362. # Real symmetric, positive definite
  1363. a = np.eye(4) + np.ones(4)
  1364. res = inv(a)
  1365. assert_allclose(res @ a, np.eye(4), atol=1e-14)
  1366. # Real symmetric, NOT positive definite
  1367. a = -np.eye(4) + np.ones(4)
  1368. res = inv(a)
  1369. assert_allclose(res @ a, np.eye(4), atol=1e-14)
  1370. # Real, not symmetric
  1371. a = -np.eye(4) + np.ones(4)
  1372. a[0, -1] = 2.
  1373. res = inv(a)
  1374. assert_allclose(res @ a, np.eye(4), atol=1e-14)
  1375. # | Test | is_symm | is_herm | pos def |
  1376. # |---------------------------------------|---------|---------|---------|
  1377. # | Complex, both sym+herm, pos def | 1 | 1 | yes |
  1378. # | Complex, symmetric only | 1 | 0 | - |
  1379. # | Complex, both sym+herm, NOT pos def | 1 | 1 | no |
  1380. # | Complex, neither | 0 | 0 | - |
  1381. # | Complex, hermitian only, pos def | 0 | 1 | yes |
  1382. # | Complex, hermitian only, NOT pos def | 0 | 1 | no |
  1383. # Complex, both symmetric and hermitian, positive definite
  1384. a = (np.eye(4) + np.ones(4)).astype(np.complex128)
  1385. res = inv(a)
  1386. assert_allclose(res @ a, np.eye(4), atol=1e-14)
  1387. # Complex, symmetric only (not hermitian)
  1388. a = (np.eye(4)*1.0j + np.ones(4)).astype(np.complex128)
  1389. res = inv(a)
  1390. assert_allclose(res @ a, np.eye(4), atol=1e-14)
  1391. # Complex, both symmetric and hermitian, NOT positive definite
  1392. a = (-np.eye(4) + np.ones(4)).astype(np.complex128)
  1393. res = inv(a)
  1394. assert_allclose(res @ a, np.eye(4), atol=1e-14)
  1395. # Complex, neither symmetric nor hermitian
  1396. a = (-np.eye(4) + np.ones(4)).astype(np.complex128)
  1397. a[0, -1] = 2.
  1398. res = inv(a)
  1399. assert_allclose(res @ a, np.eye(4), atol=1e-14)
  1400. # Complex, hermitian only, positive definite
  1401. a = np.array([[2, 1+1j], [1-1j, 2]], dtype=np.complex128)
  1402. res = inv(a)
  1403. assert_allclose(res @ a, np.eye(2), atol=1e-14)
  1404. # Complex, hermitian only, NOT positive definite
  1405. a = np.array([[-1, 1+1j], [1-1j, -1]], dtype=np.complex128)
  1406. res = inv(a)
  1407. assert_allclose(res @ a, np.eye(2), atol=1e-14)
  1408. @pytest.mark.parametrize('complex_', [False, True])
  1409. @pytest.mark.parametrize('sym_herm', ['sym', 'her'])
  1410. def test_sym_her(self, complex_, sym_herm):
  1411. # test "sym" and "her" modes
  1412. a = np.arange(9).reshape(3, 3)
  1413. if complex_:
  1414. a = a + 1j*a
  1415. if sym_herm == "sym":
  1416. b = a + a.T
  1417. else: # sym_herm == "herm":
  1418. b = a + a.T.conj()
  1419. b = b + np.eye(3)
  1420. b_inv0 = np.linalg.inv(b)
  1421. assert_allclose(b_inv0 @ b, np.eye(3), atol=1e-14)
  1422. b_inv1 = inv(b, assume_a=sym_herm)
  1423. assert_allclose(b_inv0, b_inv1, atol=1e-15)
  1424. # check that the "other" triangle is not referenced
  1425. mask = np.where(1 - np.tri(*a.shape, -1) == 0, np.nan, 1)
  1426. b_inv2 = inv(b*mask, check_finite=False, assume_a=sym_herm, lower=False)
  1427. assert_allclose(b_inv2, b_inv0, atol=1e-15)
  1428. # repeat with the upper triangle
  1429. b_inv3 = inv(b*mask.T, check_finite=False, assume_a=sym_herm, lower=True)
  1430. assert_allclose(b_inv3, b_inv0, atol=1e-15)
  1431. def test_triangular_1(self):
  1432. x = np.arange(25, dtype=float).reshape(5, 5)
  1433. y = x + x.T
  1434. y += 21*np.eye(5)
  1435. y_inv0 = inv(y, assume_a='upper triangular')
  1436. # check that upper triangular differs from posdef
  1437. y_inv_posdef = inv(y, assume_a='pos')
  1438. assert not np.allclose(y_inv0, y_inv_posdef)
  1439. def test_triangular_2(self):
  1440. y = np.ones(25, dtype=float).reshape(5, 5)
  1441. y_inv_0_u = inv(np.triu(y))
  1442. assert_allclose(y_inv_0_u @ np.triu(y), np.eye(5), atol=1e-15)
  1443. y_inv_1_u = inv(y, assume_a='upper triangular')
  1444. assert_allclose(y_inv_1_u @ np.triu(y), np.eye(5), atol=1e-15)
  1445. # check that the lower triangle is not referenced for "upper triangular"
  1446. mask = np.where(1 - np.tri(*y.shape, -1) == 0, np.nan, 1)
  1447. y_inv_2_u = inv(y*mask, check_finite=False, assume_a='upper triangular')
  1448. assert_allclose(y_inv_2_u @ np.triu(y), np.eye(5), atol=1e-15)
  1449. # repeat for the lower traingular matrix
  1450. y_inv_0_l = inv(np.tril(y))
  1451. assert_allclose(y_inv_0_l @ np.tril(y), np.eye(5), atol=1e-15)
  1452. y_inv_1_l = inv(y, assume_a='lower triangular')
  1453. assert_allclose(y_inv_1_l @ np.tril(y), np.eye(5), atol=1e-15)
  1454. # check that the lower triangle is not referenced for "lower triangular"
  1455. mask = np.where(1 - np.tri(*y.shape, -1) == 0, np.nan, 1)
  1456. y_inv_2_l = inv(y*mask.T, check_finite=False, assume_a='lower triangular')
  1457. assert_allclose(y_inv_2_l @ np.tril(y), np.eye(5), atol=1e-15)
  1458. def test_diagonal(self):
  1459. a = np.stack([np.triu(np.ones((3, 3))), np.diag(np.arange(1, 4))])
  1460. inv_a = inv(a)
  1461. # basic diagonal invert
  1462. assert_allclose(inv_a[1], np.diag(1 / np.arange(1, 4)), atol=1e-14)
  1463. # ill-conditioned inputs warn
  1464. a = np.asarray([[1e30, 0], [0, 1]])
  1465. with pytest.warns(LinAlgWarning):
  1466. inv(a, assume_a="diagonal")
  1467. # singular input raises
  1468. a = np.asarray([[0, 0], [0, 1]])
  1469. with pytest.raises(LinAlgError):
  1470. inv(a, assume_a="diagonal")
  1471. class TestDet:
  1472. def test_1x1_all_singleton_dims(self):
  1473. a = np.array([[1]])
  1474. deta = det(a)
  1475. assert deta.dtype.char == 'd'
  1476. assert np.isscalar(deta)
  1477. assert deta == 1.
  1478. a = np.array([[[[1]]]], dtype='f')
  1479. deta = det(a)
  1480. assert deta.dtype.char == 'd'
  1481. assert deta.shape == (1, 1)
  1482. assert_equal(deta, [[1.0]])
  1483. a = np.array([[[1 + 3.j]]], dtype=np.complex64)
  1484. deta = det(a)
  1485. assert deta.dtype.char == 'D'
  1486. assert deta.shape == (1,)
  1487. assert_equal(deta, [1.+3.j])
  1488. def test_1by1_stacked_input_output(self):
  1489. rng = np.random.default_rng(1680305949878959)
  1490. a = rng.random([4, 5, 1, 1], dtype=np.float32)
  1491. deta = det(a)
  1492. assert deta.dtype.char == 'd'
  1493. assert deta.shape == (4, 5)
  1494. assert_allclose(deta, np.squeeze(a))
  1495. a = rng.random([4, 5, 1, 1], dtype=np.float32)*np.complex64(1.j)
  1496. deta = det(a)
  1497. assert deta.dtype.char == 'D'
  1498. assert deta.shape == (4, 5)
  1499. assert_allclose(deta, np.squeeze(a))
  1500. @pytest.mark.parametrize('shape', [[2, 2], [20, 20], [3, 2, 20, 20]])
  1501. def test_simple_det_shapes_real_complex(self, shape):
  1502. rng = np.random.default_rng(1680305949878959)
  1503. a = rng.uniform(-1., 1., size=shape)
  1504. d1, d2 = det(a), np.linalg.det(a)
  1505. assert_allclose(d1, d2)
  1506. b = rng.uniform(-1., 1., size=shape)*1j
  1507. b += rng.uniform(-0.5, 0.5, size=shape)
  1508. d3, d4 = det(b), np.linalg.det(b)
  1509. assert_allclose(d3, d4)
  1510. def test_for_known_det_values(self):
  1511. # Hadamard8
  1512. a = np.array([[1, 1, 1, 1, 1, 1, 1, 1],
  1513. [1, -1, 1, -1, 1, -1, 1, -1],
  1514. [1, 1, -1, -1, 1, 1, -1, -1],
  1515. [1, -1, -1, 1, 1, -1, -1, 1],
  1516. [1, 1, 1, 1, -1, -1, -1, -1],
  1517. [1, -1, 1, -1, -1, 1, -1, 1],
  1518. [1, 1, -1, -1, -1, -1, 1, 1],
  1519. [1, -1, -1, 1, -1, 1, 1, -1]])
  1520. assert_allclose(det(a), 4096.)
  1521. # consecutive number array always singular
  1522. assert_allclose(det(np.arange(25).reshape(5, 5)), 0.)
  1523. # simple anti-diagonal block array
  1524. # Upper right has det (-2+1j) and lower right has (-2-1j)
  1525. # det(a) = - (-2+1j) (-2-1j) = 5.
  1526. a = np.array([[0.+0.j, 0.+0.j, 0.-1.j, 1.-1.j],
  1527. [0.+0.j, 0.+0.j, 1.+0.j, 0.-1.j],
  1528. [0.+1.j, 1.+1.j, 0.+0.j, 0.+0.j],
  1529. [1.+0.j, 0.+1.j, 0.+0.j, 0.+0.j]], dtype=np.complex64)
  1530. assert_allclose(det(a), 5.+0.j)
  1531. # Fiedler companion complexified
  1532. # >>> a = scipy.linalg.fiedler_companion(np.arange(1, 10))
  1533. a = np.array([[-2., -3., 1., 0., 0., 0., 0., 0.],
  1534. [1., 0., 0., 0., 0., 0., 0., 0.],
  1535. [0., -4., 0., -5., 1., 0., 0., 0.],
  1536. [0., 1., 0., 0., 0., 0., 0., 0.],
  1537. [0., 0., 0., -6., 0., -7., 1., 0.],
  1538. [0., 0., 0., 1., 0., 0., 0., 0.],
  1539. [0., 0., 0., 0., 0., -8., 0., -9.],
  1540. [0., 0., 0., 0., 0., 1., 0., 0.]])*1.j
  1541. assert_allclose(det(a), 9.)
  1542. # g and G dtypes are handled differently in windows and other platforms
  1543. @pytest.mark.parametrize('typ', [x for x in np.typecodes['All'][:20]
  1544. if x not in 'gG'])
  1545. def test_sample_compatible_dtype_input(self, typ):
  1546. rng = np.random.default_rng(1680305949878959)
  1547. n = 4
  1548. a = rng.random([n, n]).astype(typ) # value is not important
  1549. assert isinstance(det(a), (np.float64 | np.complex128))
  1550. def test_incompatible_dtype_input(self):
  1551. # Double backslashes needed for escaping pytest regex.
  1552. msg = 'cannot be cast to float\\(32, 64\\)'
  1553. for c, t in zip('SUO', ['bytes8', 'str32', 'object']):
  1554. with assert_raises(TypeError, match=msg):
  1555. det(np.array([['a', 'b']]*2, dtype=c))
  1556. with assert_raises(TypeError, match=msg):
  1557. det(np.array([[b'a', b'b']]*2, dtype='V'))
  1558. with assert_raises(TypeError, match=msg):
  1559. det(np.array([[100, 200]]*2, dtype='datetime64[s]'))
  1560. with assert_raises(TypeError, match=msg):
  1561. det(np.array([[100, 200]]*2, dtype='timedelta64[s]'))
  1562. def test_empty_edge_cases(self):
  1563. assert_allclose(det(np.empty([0, 0])), 1.)
  1564. assert_allclose(det(np.empty([0, 0, 0])), np.array([]))
  1565. assert_allclose(det(np.empty([3, 0, 0])), np.array([1., 1., 1.]))
  1566. with assert_raises(ValueError, match='Last 2 dimensions'):
  1567. det(np.empty([0, 0, 3]))
  1568. with assert_raises(ValueError, match='at least two-dimensional'):
  1569. det(np.array([]))
  1570. with assert_raises(ValueError, match='Last 2 dimensions'):
  1571. det(np.array([[]]))
  1572. with assert_raises(ValueError, match='Last 2 dimensions'):
  1573. det(np.array([[[]]]))
  1574. @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  1575. def test_empty_dtype(self, dt):
  1576. a = np.empty((0, 0), dtype=dt)
  1577. d = det(a)
  1578. assert d.shape == ()
  1579. assert d.dtype == det(np.eye(2, dtype=dt)).dtype
  1580. a = np.empty((3, 0, 0), dtype=dt)
  1581. d = det(a)
  1582. assert d.shape == (3,)
  1583. assert d.dtype == det(np.zeros((3, 1, 1), dtype=dt)).dtype
  1584. def test_overwrite_a(self):
  1585. # If all conditions are met then input should be overwritten;
  1586. # - dtype is one of 'fdFD'
  1587. # - C-contiguous
  1588. # - writeable
  1589. a = np.arange(9).reshape(3, 3).astype(np.float32)
  1590. ac = a.copy()
  1591. deta = det(ac, overwrite_a=True)
  1592. assert_allclose(deta, 0.)
  1593. assert not (a == ac).all()
  1594. def test_readonly_array(self):
  1595. a = np.array([[2., 0., 1.], [5., 3., -1.], [1., 1., 1.]])
  1596. a.setflags(write=False)
  1597. # overwrite_a will be overridden
  1598. assert_allclose(det(a, overwrite_a=True), 10.)
  1599. def test_simple_check_finite(self):
  1600. a = [[1, 2], [3, np.inf]]
  1601. with assert_raises(ValueError, match='array must not contain'):
  1602. det(a)
  1603. def direct_lstsq(a, b, cmplx=0):
  1604. at = transpose(a)
  1605. if cmplx:
  1606. at = conjugate(at)
  1607. a1 = dot(at, a)
  1608. b1 = dot(at, b)
  1609. return solve(a1, b1)
  1610. class TestLstsq:
  1611. lapack_drivers = ('gelsd', 'gelss', 'gelsy', None)
  1612. def test_simple_exact(self):
  1613. for dtype in REAL_DTYPES:
  1614. a = np.array([[1, 20], [-30, 4]], dtype=dtype)
  1615. for lapack_driver in TestLstsq.lapack_drivers:
  1616. for overwrite in (True, False):
  1617. for bt in (((1, 0), (0, 1)), (1, 0),
  1618. ((2, 1), (-30, 4))):
  1619. # Store values in case they are overwritten
  1620. # later
  1621. a1 = a.copy()
  1622. b = np.array(bt, dtype=dtype)
  1623. b1 = b.copy()
  1624. out = lstsq(a1, b1,
  1625. lapack_driver=lapack_driver,
  1626. overwrite_a=overwrite,
  1627. overwrite_b=overwrite)
  1628. x = out[0]
  1629. r = out[2]
  1630. assert_(r == 2,
  1631. f'expected efficient rank 2, got {r}')
  1632. assert_allclose(dot(a, x), b,
  1633. atol=25 * _eps_cast(a1.dtype),
  1634. rtol=25 * _eps_cast(a1.dtype),
  1635. err_msg=f"driver: {lapack_driver}")
  1636. def test_simple_overdet(self):
  1637. for dtype in REAL_DTYPES:
  1638. a = np.array([[1, 2], [4, 5], [3, 4]], dtype=dtype)
  1639. b = np.array([1, 2, 3], dtype=dtype)
  1640. for lapack_driver in TestLstsq.lapack_drivers:
  1641. for overwrite in (True, False):
  1642. # Store values in case they are overwritten later
  1643. a1 = a.copy()
  1644. b1 = b.copy()
  1645. out = lstsq(a1, b1, lapack_driver=lapack_driver,
  1646. overwrite_a=overwrite,
  1647. overwrite_b=overwrite)
  1648. x = out[0]
  1649. if lapack_driver == 'gelsy':
  1650. residuals = np.sum((b - a.dot(x))**2)
  1651. else:
  1652. residuals = out[1]
  1653. r = out[2]
  1654. assert_(r == 2, f'expected efficient rank 2, got {r}')
  1655. assert_allclose(abs((dot(a, x) - b)**2).sum(axis=0),
  1656. residuals,
  1657. rtol=25 * _eps_cast(a1.dtype),
  1658. atol=25 * _eps_cast(a1.dtype),
  1659. err_msg=f"driver: {lapack_driver}")
  1660. assert_allclose(x, (-0.428571428571429, 0.85714285714285),
  1661. rtol=25 * _eps_cast(a1.dtype),
  1662. atol=25 * _eps_cast(a1.dtype),
  1663. err_msg=f"driver: {lapack_driver}")
  1664. def test_simple_overdet_complex(self):
  1665. for dtype in COMPLEX_DTYPES:
  1666. a = np.array([[1+2j, 2], [4, 5], [3, 4]], dtype=dtype)
  1667. b = np.array([1, 2+4j, 3], dtype=dtype)
  1668. for lapack_driver in TestLstsq.lapack_drivers:
  1669. for overwrite in (True, False):
  1670. # Store values in case they are overwritten later
  1671. a1 = a.copy()
  1672. b1 = b.copy()
  1673. out = lstsq(a1, b1, lapack_driver=lapack_driver,
  1674. overwrite_a=overwrite,
  1675. overwrite_b=overwrite)
  1676. x = out[0]
  1677. if lapack_driver == 'gelsy':
  1678. res = b - a.dot(x)
  1679. residuals = np.sum(res * res.conj())
  1680. else:
  1681. residuals = out[1]
  1682. r = out[2]
  1683. assert_(r == 2, f'expected efficient rank 2, got {r}')
  1684. assert_allclose(abs((dot(a, x) - b)**2).sum(axis=0),
  1685. residuals,
  1686. rtol=25 * _eps_cast(a1.dtype),
  1687. atol=25 * _eps_cast(a1.dtype),
  1688. err_msg=f"driver: {lapack_driver}")
  1689. assert_allclose(
  1690. x, (-0.4831460674157303 + 0.258426966292135j,
  1691. 0.921348314606741 + 0.292134831460674j),
  1692. rtol=25 * _eps_cast(a1.dtype),
  1693. atol=25 * _eps_cast(a1.dtype),
  1694. err_msg=f"driver: {lapack_driver}")
  1695. def test_simple_underdet(self):
  1696. for dtype in REAL_DTYPES:
  1697. a = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
  1698. b = np.array([1, 2], dtype=dtype)
  1699. for lapack_driver in TestLstsq.lapack_drivers:
  1700. for overwrite in (True, False):
  1701. # Store values in case they are overwritten later
  1702. a1 = a.copy()
  1703. b1 = b.copy()
  1704. out = lstsq(a1, b1, lapack_driver=lapack_driver,
  1705. overwrite_a=overwrite,
  1706. overwrite_b=overwrite)
  1707. x = out[0]
  1708. r = out[2]
  1709. assert_(r == 2, f'expected efficient rank 2, got {r}')
  1710. assert_allclose(x, (-0.055555555555555, 0.111111111111111,
  1711. 0.277777777777777),
  1712. rtol=25 * _eps_cast(a1.dtype),
  1713. atol=25 * _eps_cast(a1.dtype),
  1714. err_msg=f"driver: {lapack_driver}")
  1715. @pytest.mark.parametrize("dtype", REAL_DTYPES)
  1716. @pytest.mark.parametrize("n", (20, 200))
  1717. @pytest.mark.parametrize("lapack_driver", lapack_drivers)
  1718. @pytest.mark.parametrize("overwrite", (True, False))
  1719. def test_random_exact(self, dtype, n, lapack_driver, overwrite):
  1720. rng = np.random.RandomState(1234)
  1721. a = np.asarray(rng.random([n, n]), dtype=dtype)
  1722. for i in range(n):
  1723. a[i, i] = 20 * (0.1 + a[i, i])
  1724. for i in range(4):
  1725. b = np.asarray(rng.random([n, 3]), dtype=dtype)
  1726. # Store values in case they are overwritten later
  1727. a1 = a.copy()
  1728. b1 = b.copy()
  1729. out = lstsq(a1, b1,
  1730. lapack_driver=lapack_driver,
  1731. overwrite_a=overwrite,
  1732. overwrite_b=overwrite)
  1733. x = out[0]
  1734. r = out[2]
  1735. assert_(r == n, f'expected efficient rank {n}, '
  1736. f'got {r}')
  1737. if dtype is np.float32:
  1738. assert_allclose(
  1739. dot(a, x), b,
  1740. rtol=500 * _eps_cast(a1.dtype),
  1741. atol=500 * _eps_cast(a1.dtype),
  1742. err_msg=f"driver: {lapack_driver}")
  1743. else:
  1744. assert_allclose(
  1745. dot(a, x), b,
  1746. rtol=1000 * _eps_cast(a1.dtype),
  1747. atol=1000 * _eps_cast(a1.dtype),
  1748. err_msg=f"driver: {lapack_driver}")
  1749. @pytest.mark.skipif(IS_MUSL, reason="may segfault on Alpine, see gh-17630")
  1750. @pytest.mark.parametrize("dtype", COMPLEX_DTYPES)
  1751. @pytest.mark.parametrize("n", (20, 200))
  1752. @pytest.mark.parametrize("lapack_driver", lapack_drivers)
  1753. @pytest.mark.parametrize("overwrite", (True, False))
  1754. def test_random_complex_exact(self, dtype, n, lapack_driver, overwrite):
  1755. rng = np.random.RandomState(1234)
  1756. a = np.asarray(rng.random([n, n]) + 1j*rng.random([n, n]),
  1757. dtype=dtype)
  1758. for i in range(n):
  1759. a[i, i] = 20 * (0.1 + a[i, i])
  1760. for i in range(2):
  1761. b = np.asarray(rng.random([n, 3]), dtype=dtype)
  1762. # Store values in case they are overwritten later
  1763. a1 = a.copy()
  1764. b1 = b.copy()
  1765. out = lstsq(a1, b1, lapack_driver=lapack_driver,
  1766. overwrite_a=overwrite,
  1767. overwrite_b=overwrite)
  1768. x = out[0]
  1769. r = out[2]
  1770. assert_(r == n, f'expected efficient rank {n}, '
  1771. f'got {r}')
  1772. if dtype is np.complex64:
  1773. assert_allclose(
  1774. dot(a, x), b,
  1775. rtol=400 * _eps_cast(a1.dtype),
  1776. atol=400 * _eps_cast(a1.dtype),
  1777. err_msg=f"driver: {lapack_driver}")
  1778. else:
  1779. assert_allclose(
  1780. dot(a, x), b,
  1781. rtol=1000 * _eps_cast(a1.dtype),
  1782. atol=1000 * _eps_cast(a1.dtype),
  1783. err_msg=f"driver: {lapack_driver}")
  1784. def test_random_overdet(self):
  1785. rng = np.random.RandomState(1234)
  1786. for dtype in REAL_DTYPES:
  1787. for (n, m) in ((20, 15), (200, 2)):
  1788. for lapack_driver in TestLstsq.lapack_drivers:
  1789. for overwrite in (True, False):
  1790. a = np.asarray(rng.random([n, m]), dtype=dtype)
  1791. for i in range(m):
  1792. a[i, i] = 20 * (0.1 + a[i, i])
  1793. for i in range(4):
  1794. b = np.asarray(rng.random([n, 3]), dtype=dtype)
  1795. # Store values in case they are overwritten later
  1796. a1 = a.copy()
  1797. b1 = b.copy()
  1798. out = lstsq(a1, b1,
  1799. lapack_driver=lapack_driver,
  1800. overwrite_a=overwrite,
  1801. overwrite_b=overwrite)
  1802. x = out[0]
  1803. r = out[2]
  1804. assert_(r == m, f'expected efficient rank {m}, '
  1805. f'got {r}')
  1806. assert_allclose(
  1807. x, direct_lstsq(a, b, cmplx=0),
  1808. rtol=25 * _eps_cast(a1.dtype),
  1809. atol=25 * _eps_cast(a1.dtype),
  1810. err_msg=f"driver: {lapack_driver}")
  1811. def test_random_complex_overdet(self):
  1812. rng = np.random.RandomState(1234)
  1813. for dtype in COMPLEX_DTYPES:
  1814. for (n, m) in ((20, 15), (200, 2)):
  1815. for lapack_driver in TestLstsq.lapack_drivers:
  1816. for overwrite in (True, False):
  1817. a = np.asarray(rng.random([n, m]) + 1j*rng.random([n, m]),
  1818. dtype=dtype)
  1819. for i in range(m):
  1820. a[i, i] = 20 * (0.1 + a[i, i])
  1821. for i in range(2):
  1822. b = np.asarray(rng.random([n, 3]), dtype=dtype)
  1823. # Store values in case they are overwritten
  1824. # later
  1825. a1 = a.copy()
  1826. b1 = b.copy()
  1827. out = lstsq(a1, b1,
  1828. lapack_driver=lapack_driver,
  1829. overwrite_a=overwrite,
  1830. overwrite_b=overwrite)
  1831. x = out[0]
  1832. r = out[2]
  1833. assert_(r == m, f'expected efficient rank {m}, '
  1834. f'got {r}')
  1835. assert_allclose(
  1836. x, direct_lstsq(a, b, cmplx=1),
  1837. rtol=25 * _eps_cast(a1.dtype),
  1838. atol=25 * _eps_cast(a1.dtype),
  1839. err_msg=f"driver: {lapack_driver}")
  1840. def test_check_finite(self):
  1841. with warnings.catch_warnings():
  1842. # On (some) OSX this tests triggers a warning (gh-7538)
  1843. warnings.filterwarnings("ignore",
  1844. "internal gelsd driver lwork query error,.*"
  1845. "Falling back to 'gelss' driver.", RuntimeWarning)
  1846. at = np.array(((1, 20), (-30, 4)))
  1847. for dtype, bt, lapack_driver, overwrite, check_finite in \
  1848. itertools.product(REAL_DTYPES,
  1849. (((1, 0), (0, 1)), (1, 0), ((2, 1), (-30, 4))),
  1850. TestLstsq.lapack_drivers,
  1851. (True, False),
  1852. (True, False)):
  1853. a = at.astype(dtype)
  1854. b = np.array(bt, dtype=dtype)
  1855. # Store values in case they are overwritten
  1856. # later
  1857. a1 = a.copy()
  1858. b1 = b.copy()
  1859. out = lstsq(a1, b1, lapack_driver=lapack_driver,
  1860. check_finite=check_finite, overwrite_a=overwrite,
  1861. overwrite_b=overwrite)
  1862. x = out[0]
  1863. r = out[2]
  1864. assert_(r == 2, f'expected efficient rank 2, got {r}')
  1865. assert_allclose(dot(a, x), b,
  1866. rtol=25 * _eps_cast(a.dtype),
  1867. atol=25 * _eps_cast(a.dtype),
  1868. err_msg=f"driver: {lapack_driver}")
  1869. def test_empty(self):
  1870. for a_shape, b_shape in (((0, 2), (0,)),
  1871. ((0, 4), (0, 2)),
  1872. ((4, 0), (4,)),
  1873. ((4, 0), (4, 2))):
  1874. b = np.ones(b_shape)
  1875. x, residues, rank, s = lstsq(np.zeros(a_shape), b)
  1876. assert_equal(x, np.zeros((a_shape[1],) + b_shape[1:]))
  1877. residues_should_be = (np.empty((0,)) if a_shape[1]
  1878. else np.linalg.norm(b, axis=0)**2)
  1879. assert_equal(residues, residues_should_be)
  1880. assert_(rank == 0, 'expected rank 0')
  1881. assert_equal(s, np.empty((0,)))
  1882. @pytest.mark.parametrize('dt_a', [int, float, np.float32, complex, np.complex64])
  1883. @pytest.mark.parametrize('dt_b', [int, float, np.float32, complex, np.complex64])
  1884. def test_empty_dtype(self, dt_a, dt_b):
  1885. a = np.empty((0, 0), dtype=dt_a)
  1886. b = np.empty(0, dtype=dt_b)
  1887. x, residues, rank, s = lstsq(a, b)
  1888. assert x.size == 0
  1889. dt_nonempty = lstsq(np.eye(2, dtype=dt_a), np.ones(2, dtype=dt_b))[0].dtype
  1890. assert x.dtype == dt_nonempty
  1891. class TestPinv:
  1892. def test_simple_real(self):
  1893. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], dtype=float)
  1894. a_pinv = pinv(a)
  1895. assert_array_almost_equal(dot(a, a_pinv), np.eye(3))
  1896. def test_simple_complex(self):
  1897. a = (array([[1, 2, 3], [4, 5, 6], [7, 8, 10]],
  1898. dtype=float) + 1j * array([[10, 8, 7], [6, 5, 4], [3, 2, 1]],
  1899. dtype=float))
  1900. a_pinv = pinv(a)
  1901. assert_array_almost_equal(dot(a, a_pinv), np.eye(3))
  1902. def test_simple_singular(self):
  1903. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=float)
  1904. a_pinv = pinv(a)
  1905. expected = array([[-6.38888889e-01, -1.66666667e-01, 3.05555556e-01],
  1906. [-5.55555556e-02, 1.30136518e-16, 5.55555556e-02],
  1907. [5.27777778e-01, 1.66666667e-01, -1.94444444e-01]])
  1908. assert_array_almost_equal(a_pinv, expected)
  1909. def test_simple_cols(self):
  1910. a = array([[1, 2, 3], [4, 5, 6]], dtype=float)
  1911. a_pinv = pinv(a)
  1912. expected = array([[-0.94444444, 0.44444444],
  1913. [-0.11111111, 0.11111111],
  1914. [0.72222222, -0.22222222]])
  1915. assert_array_almost_equal(a_pinv, expected)
  1916. def test_simple_rows(self):
  1917. a = array([[1, 2], [3, 4], [5, 6]], dtype=float)
  1918. a_pinv = pinv(a)
  1919. expected = array([[-1.33333333, -0.33333333, 0.66666667],
  1920. [1.08333333, 0.33333333, -0.41666667]])
  1921. assert_array_almost_equal(a_pinv, expected)
  1922. def test_check_finite(self):
  1923. a = array([[1, 2, 3], [4, 5, 6.], [7, 8, 10]])
  1924. a_pinv = pinv(a, check_finite=False)
  1925. assert_array_almost_equal(dot(a, a_pinv), np.eye(3))
  1926. def test_native_list_argument(self):
  1927. a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
  1928. a_pinv = pinv(a)
  1929. expected = array([[-6.38888889e-01, -1.66666667e-01, 3.05555556e-01],
  1930. [-5.55555556e-02, 1.30136518e-16, 5.55555556e-02],
  1931. [5.27777778e-01, 1.66666667e-01, -1.94444444e-01]])
  1932. assert_array_almost_equal(a_pinv, expected)
  1933. def test_atol_rtol(self):
  1934. rng = np.random.default_rng(1234)
  1935. n = 12
  1936. # get a random ortho matrix for shuffling
  1937. q, _ = qr(rng.random((n, n)))
  1938. a_m = np.arange(35.0).reshape(7, 5)
  1939. a = a_m.copy()
  1940. a[0, 0] = 0.001
  1941. atol = 1e-5
  1942. rtol = 0.05
  1943. # svds of a_m is ~ [116.906, 4.234, tiny, tiny, tiny]
  1944. # svds of a is ~ [116.906, 4.234, 4.62959e-04, tiny, tiny]
  1945. # Just abs cutoff such that we arrive at a_modified
  1946. a_p = pinv(a_m, atol=atol, rtol=0.)
  1947. adiff1 = a @ a_p @ a - a
  1948. adiff2 = a_m @ a_p @ a_m - a_m
  1949. # Now adiff1 should be around atol value while adiff2 should be
  1950. # relatively tiny
  1951. assert_allclose(np.linalg.norm(adiff1), 5e-4, atol=5.e-4)
  1952. assert_allclose(np.linalg.norm(adiff2), 5e-14, atol=5.e-14)
  1953. # Now do the same but remove another sv ~4.234 via rtol
  1954. a_p = pinv(a_m, atol=atol, rtol=rtol)
  1955. adiff1 = a @ a_p @ a - a
  1956. adiff2 = a_m @ a_p @ a_m - a_m
  1957. assert_allclose(np.linalg.norm(adiff1), 4.233, rtol=0.01)
  1958. assert_allclose(np.linalg.norm(adiff2), 4.233, rtol=0.01)
  1959. @pytest.mark.parametrize('dt', [float, np.float32, complex, np.complex64])
  1960. def test_empty(self, dt):
  1961. a = np.empty((0, 0), dtype=dt)
  1962. a_pinv = pinv(a)
  1963. assert a_pinv.size == 0
  1964. assert a_pinv.dtype == pinv(np.eye(2, dtype=dt)).dtype
  1965. class TestPinvSymmetric:
  1966. def test_simple_real(self):
  1967. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], dtype=float)
  1968. a = np.dot(a, a.T)
  1969. a_pinv = pinvh(a)
  1970. assert_array_almost_equal(np.dot(a, a_pinv), np.eye(3))
  1971. def test_nonpositive(self):
  1972. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=float)
  1973. a = np.dot(a, a.T)
  1974. u, s, vt = np.linalg.svd(a)
  1975. s[0] *= -1
  1976. a = np.dot(u * s, vt) # a is now symmetric non-positive and singular
  1977. a_pinv = pinv(a)
  1978. a_pinvh = pinvh(a)
  1979. assert_array_almost_equal(a_pinv, a_pinvh)
  1980. def test_simple_complex(self):
  1981. a = (array([[1, 2, 3], [4, 5, 6], [7, 8, 10]],
  1982. dtype=float) + 1j * array([[10, 8, 7], [6, 5, 4], [3, 2, 1]],
  1983. dtype=float))
  1984. a = np.dot(a, a.conj().T)
  1985. a_pinv = pinvh(a)
  1986. assert_array_almost_equal(np.dot(a, a_pinv), np.eye(3))
  1987. def test_native_list_argument(self):
  1988. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], dtype=float)
  1989. a = np.dot(a, a.T)
  1990. a_pinv = pinvh(a.tolist())
  1991. assert_array_almost_equal(np.dot(a, a_pinv), np.eye(3))
  1992. def test_zero_eigenvalue(self):
  1993. # https://github.com/scipy/scipy/issues/12515
  1994. # the SYEVR eigh driver may give the zero eigenvalue > eps
  1995. a = np.array([[1, -1, 0], [-1, 2, -1], [0, -1, 1]])
  1996. p = pinvh(a)
  1997. assert_allclose(p @ a @ p, p, atol=1e-15)
  1998. assert_allclose(a @ p @ a, a, atol=1e-15)
  1999. def test_atol_rtol(self):
  2000. rng = np.random.default_rng(1234)
  2001. n = 12
  2002. # get a random ortho matrix for shuffling
  2003. q, _ = qr(rng.random((n, n)))
  2004. a = np.diag([4, 3, 2, 1, 0.99e-4, 0.99e-5] + [0.99e-6]*(n-6))
  2005. a = q.T @ a @ q
  2006. a_m = np.diag([4, 3, 2, 1, 0.99e-4, 0.] + [0.]*(n-6))
  2007. a_m = q.T @ a_m @ q
  2008. atol = 1e-5
  2009. rtol = (4.01e-4 - 4e-5)/4
  2010. # Just abs cutoff such that we arrive at a_modified
  2011. a_p = pinvh(a, atol=atol, rtol=0.)
  2012. adiff1 = a @ a_p @ a - a
  2013. adiff2 = a_m @ a_p @ a_m - a_m
  2014. # Now adiff1 should dance around atol value since truncation
  2015. # while adiff2 should be relatively tiny
  2016. assert_allclose(norm(adiff1), atol, rtol=0.1)
  2017. assert_allclose(norm(adiff2), 1e-12, atol=1e-11)
  2018. # Now do the same but through rtol cancelling atol value
  2019. a_p = pinvh(a, atol=atol, rtol=rtol)
  2020. adiff1 = a @ a_p @ a - a
  2021. adiff2 = a_m @ a_p @ a_m - a_m
  2022. # adiff1 and adiff2 should be elevated to ~1e-4 due to mismatch
  2023. assert_allclose(norm(adiff1), 1e-4, rtol=0.1)
  2024. assert_allclose(norm(adiff2), 1e-4, rtol=0.1)
  2025. @pytest.mark.parametrize('dt', [float, np.float32, complex, np.complex64])
  2026. def test_empty(self, dt):
  2027. a = np.empty((0, 0), dtype=dt)
  2028. a_pinv = pinvh(a)
  2029. assert a_pinv.size == 0
  2030. assert a_pinv.dtype == pinv(np.eye(2, dtype=dt)).dtype
  2031. @pytest.mark.parametrize('scale', (1e-20, 1., 1e20))
  2032. @pytest.mark.parametrize('pinv_', (pinv, pinvh))
  2033. def test_auto_rcond(scale, pinv_):
  2034. x = np.array([[1, 0], [0, 1e-10]]) * scale
  2035. expected = np.diag(1. / np.diag(x))
  2036. x_inv = pinv_(x)
  2037. assert_allclose(x_inv, expected)
  2038. class TestVectorNorms:
  2039. def test_types(self):
  2040. for dtype in np.typecodes['AllFloat']:
  2041. x = np.array([1, 2, 3], dtype=dtype)
  2042. tol = max(1e-15, np.finfo(dtype).eps.real * 20)
  2043. assert_allclose(norm(x), np.sqrt(14), rtol=tol)
  2044. assert_allclose(norm(x, 2), np.sqrt(14), rtol=tol)
  2045. for dtype in np.typecodes['Complex']:
  2046. x = np.array([1j, 2j, 3j], dtype=dtype)
  2047. tol = max(1e-15, np.finfo(dtype).eps.real * 20)
  2048. assert_allclose(norm(x), np.sqrt(14), rtol=tol)
  2049. assert_allclose(norm(x, 2), np.sqrt(14), rtol=tol)
  2050. def test_overflow(self):
  2051. # unlike numpy's norm, this one is
  2052. # safer on overflow
  2053. a = array([1e20], dtype=float32)
  2054. assert_almost_equal(norm(a), a)
  2055. def test_stable(self):
  2056. # more stable than numpy's norm
  2057. a = array([1e4] + [1]*10000, dtype=float32)
  2058. try:
  2059. # snrm in double precision; we obtain the same as for float64
  2060. # -- large atol needed due to varying blas implementations
  2061. assert_allclose(norm(a) - 1e4, 0.5, atol=1e-2)
  2062. except AssertionError:
  2063. # snrm implemented in single precision, == np.linalg.norm result
  2064. msg = ": Result should equal either 0.0 or 0.5 (depending on " \
  2065. "implementation of snrm2)."
  2066. assert_almost_equal(norm(a) - 1e4, 0.0, err_msg=msg)
  2067. def test_zero_norm(self):
  2068. assert_equal(norm([1, 0, 3], 0), 2)
  2069. assert_equal(norm([1, 2, 3], 0), 3)
  2070. def test_axis_kwd(self):
  2071. a = np.array([[[2, 1], [3, 4]]] * 2, 'd')
  2072. assert_allclose(norm(a, axis=1), [[3.60555128, 4.12310563]] * 2)
  2073. assert_allclose(norm(a, 1, axis=1), [[5.] * 2] * 2)
  2074. def test_keepdims_kwd(self):
  2075. a = np.array([[[2, 1], [3, 4]]] * 2, 'd')
  2076. b = norm(a, axis=1, keepdims=True)
  2077. assert_allclose(b, [[[3.60555128, 4.12310563]]] * 2)
  2078. assert_(b.shape == (2, 1, 2))
  2079. assert_allclose(norm(a, 1, axis=2, keepdims=True), [[[3.], [7.]]] * 2)
  2080. @pytest.mark.skipif(not HAS_ILP64, reason="64-bit BLAS required")
  2081. def test_large_vector(self):
  2082. check_free_memory(free_mb=17000)
  2083. x = np.zeros([2**31], dtype=np.float64)
  2084. x[-1] = 1
  2085. res = norm(x)
  2086. del x
  2087. assert_allclose(res, 1.0)
  2088. class TestMatrixNorms:
  2089. def test_matrix_norms(self):
  2090. # Not all of these are matrix norms in the most technical sense.
  2091. rng = np.random.default_rng(1234)
  2092. for n, m in (1, 1), (1, 3), (3, 1), (4, 4), (4, 5), (5, 4):
  2093. for t in np.float32, np.float64, np.complex64, np.complex128, np.int64:
  2094. A = 10 * rng.standard_normal((n, m)).astype(t)
  2095. if np.issubdtype(A.dtype, np.complexfloating):
  2096. A += 10j * rng.standard_normal((n, m))
  2097. t_high = np.complex128
  2098. else:
  2099. t_high = np.float64
  2100. for order in (None, 'fro', 1, -1, 2, -2, np.inf, -np.inf):
  2101. actual = norm(A, ord=order)
  2102. desired = np.linalg.norm(A, ord=order)
  2103. # SciPy may return higher precision matrix norms.
  2104. # This is a consequence of using LAPACK.
  2105. if not np.allclose(actual, desired):
  2106. desired = np.linalg.norm(A.astype(t_high), ord=order)
  2107. assert_allclose(actual, desired)
  2108. def test_axis_kwd(self):
  2109. a = np.array([[[2, 1], [3, 4]]] * 2, 'd')
  2110. b = norm(a, ord=np.inf, axis=(1, 0))
  2111. c = norm(np.swapaxes(a, 0, 1), ord=np.inf, axis=(0, 1))
  2112. d = norm(a, ord=1, axis=(0, 1))
  2113. assert_allclose(b, c)
  2114. assert_allclose(c, d)
  2115. assert_allclose(b, d)
  2116. assert_(b.shape == c.shape == d.shape)
  2117. b = norm(a, ord=1, axis=(1, 0))
  2118. c = norm(np.swapaxes(a, 0, 1), ord=1, axis=(0, 1))
  2119. d = norm(a, ord=np.inf, axis=(0, 1))
  2120. assert_allclose(b, c)
  2121. assert_allclose(c, d)
  2122. assert_allclose(b, d)
  2123. assert_(b.shape == c.shape == d.shape)
  2124. def test_keepdims_kwd(self):
  2125. a = np.arange(120, dtype='d').reshape(2, 3, 4, 5)
  2126. b = norm(a, ord=np.inf, axis=(1, 0), keepdims=True)
  2127. c = norm(a, ord=1, axis=(0, 1), keepdims=True)
  2128. assert_allclose(b, c)
  2129. assert_(b.shape == c.shape)
  2130. def test_empty(self):
  2131. a = np.empty((0, 0))
  2132. assert_allclose(norm(a), 0.)
  2133. assert_allclose(norm(a, axis=0), np.zeros((0,)))
  2134. assert_allclose(norm(a, keepdims=True), np.zeros((1, 1)))
  2135. a = np.empty((0, 3))
  2136. assert_allclose(norm(a), 0.)
  2137. assert_allclose(norm(a, axis=0), np.zeros((3,)))
  2138. assert_allclose(norm(a, keepdims=True), np.zeros((1, 1)))
  2139. class TestOverwrite:
  2140. def test_solve(self):
  2141. assert_no_overwrite(solve, [(3, 3), (3,)])
  2142. def test_solve_triangular(self):
  2143. assert_no_overwrite(solve_triangular, [(3, 3), (3,)])
  2144. def test_solve_banded(self):
  2145. assert_no_overwrite(lambda ab, b: solve_banded((2, 1), ab, b),
  2146. [(4, 6), (6,)])
  2147. def test_solveh_banded(self):
  2148. assert_no_overwrite(solveh_banded, [(2, 6), (6,)])
  2149. def test_inv(self):
  2150. assert_no_overwrite(inv, [(3, 3)])
  2151. def test_det(self):
  2152. assert_no_overwrite(det, [(3, 3)])
  2153. def test_lstsq(self):
  2154. assert_no_overwrite(lstsq, [(3, 2), (3,)])
  2155. def test_pinv(self):
  2156. assert_no_overwrite(pinv, [(3, 3)])
  2157. def test_pinvh(self):
  2158. assert_no_overwrite(pinvh, [(3, 3)])
  2159. class TestSolveCirculant:
  2160. def test_basic1(self):
  2161. c = np.array([1, 2, 3, 5])
  2162. b = np.array([1, -1, 1, 0])
  2163. x = solve_circulant(c, b)
  2164. y = solve(circulant(c), b)
  2165. assert_allclose(x, y)
  2166. def test_basic2(self):
  2167. # b is a 2-d matrix.
  2168. c = np.array([1, 2, -3, -5])
  2169. b = np.arange(12).reshape(4, 3)
  2170. x = solve_circulant(c, b)
  2171. y = solve(circulant(c), b)
  2172. assert_allclose(x, y)
  2173. def test_basic3(self):
  2174. # b is a 3-d matrix.
  2175. c = np.array([1, 2, -3, -5])
  2176. b = np.arange(24).reshape(4, 3, 2)
  2177. x = solve_circulant(c, b)
  2178. y = solve(circulant(c), b.reshape(4, -1)).reshape(b.shape)
  2179. assert_allclose(x, y)
  2180. def test_complex(self):
  2181. # Complex b and c
  2182. c = np.array([1+2j, -3, 4j, 5])
  2183. b = np.arange(8).reshape(4, 2) + 0.5j
  2184. x = solve_circulant(c, b)
  2185. y = solve(circulant(c), b)
  2186. assert_allclose(x, y)
  2187. def test_random_b_and_c(self):
  2188. # Random b and c
  2189. rng = np.random.RandomState(54321)
  2190. c = rng.standard_normal(50)
  2191. b = rng.standard_normal(50)
  2192. x = solve_circulant(c, b)
  2193. y = solve(circulant(c), b)
  2194. assert_allclose(x, y)
  2195. def test_singular(self):
  2196. # c gives a singular circulant matrix.
  2197. c = np.array([1, 1, 0, 0])
  2198. b = np.array([1, 2, 3, 4])
  2199. x = solve_circulant(c, b, singular='lstsq')
  2200. y, res, rnk, s = lstsq(circulant(c), b)
  2201. assert_allclose(x, y)
  2202. assert_raises(LinAlgError, solve_circulant, x, y)
  2203. def test_axis_args(self):
  2204. # Test use of caxis, baxis and outaxis.
  2205. # c has shape (2, 1, 4)
  2206. c = np.array([[[-1, 2.5, 3, 3.5]], [[1, 6, 6, 6.5]]])
  2207. # b has shape (3, 4)
  2208. b = np.array([[0, 0, 1, 1], [1, 1, 0, 0], [1, -1, 0, 0]])
  2209. x = solve_circulant(c, b, baxis=1)
  2210. assert_equal(x.shape, (4, 2, 3))
  2211. expected = np.empty_like(x)
  2212. expected[:, 0, :] = solve(circulant(c[0].ravel()), b.T)
  2213. expected[:, 1, :] = solve(circulant(c[1].ravel()), b.T)
  2214. assert_allclose(x, expected)
  2215. x = solve_circulant(c, b, baxis=1, outaxis=-1)
  2216. assert_equal(x.shape, (2, 3, 4))
  2217. assert_allclose(np.moveaxis(x, -1, 0), expected)
  2218. # np.swapaxes(c, 1, 2) has shape (2, 4, 1); b.T has shape (4, 3).
  2219. x = solve_circulant(np.swapaxes(c, 1, 2), b.T, caxis=1)
  2220. assert_equal(x.shape, (4, 2, 3))
  2221. assert_allclose(x, expected)
  2222. def test_native_list_arguments(self):
  2223. # Same as test_basic1 using python's native list.
  2224. c = [1, 2, 3, 5]
  2225. b = [1, -1, 1, 0]
  2226. x = solve_circulant(c, b)
  2227. y = solve(circulant(c), b)
  2228. assert_allclose(x, y)
  2229. @pytest.mark.parametrize('dt_c', [int, float, np.float32, complex, np.complex64])
  2230. @pytest.mark.parametrize('dt_b', [int, float, np.float32, complex, np.complex64])
  2231. def test_empty(self, dt_c, dt_b):
  2232. c = np.array([], dtype=dt_c)
  2233. b = np.array([], dtype=dt_b)
  2234. x = solve_circulant(c, b)
  2235. assert x.shape == (0,)
  2236. assert x.dtype == solve_circulant(np.arange(3, dtype=dt_c),
  2237. np.ones(3, dtype=dt_b)).dtype
  2238. b = np.empty((0, 0), dtype=dt_b)
  2239. x1 = solve_circulant(c, b)
  2240. assert x1.shape == (0, 0)
  2241. assert x1.dtype == x.dtype
  2242. class TestMatrix_Balance:
  2243. @skip_xp_invalid_arg
  2244. def test_string_arg(self):
  2245. assert_raises(ValueError, matrix_balance, 'Some string for fail')
  2246. def test_infnan_arg(self):
  2247. assert_raises(ValueError, matrix_balance,
  2248. np.array([[1, 2], [3, np.inf]]))
  2249. assert_raises(ValueError, matrix_balance,
  2250. np.array([[1, 2], [3, np.nan]]))
  2251. def test_scaling(self):
  2252. _, y = matrix_balance(np.array([[1000, 1], [1000, 0]]))
  2253. # Pre/post LAPACK 3.5.0 gives the same result up to an offset
  2254. # since in each case col norm is x1000 greater and
  2255. # 1000 / 32 ~= 1 * 32 hence balanced with 2 ** 5.
  2256. assert_allclose(np.diff(np.log2(np.diag(y))), [5])
  2257. def test_scaling_order(self):
  2258. A = np.array([[1, 0, 1e-4], [1, 1, 1e-2], [1e4, 1e2, 1]])
  2259. x, y = matrix_balance(A)
  2260. assert_allclose(solve(y, A).dot(y), x)
  2261. def test_separate(self):
  2262. _, (y, z) = matrix_balance(np.array([[1000, 1], [1000, 0]]),
  2263. separate=1)
  2264. assert_equal(np.diff(np.log2(y)), [5])
  2265. assert_allclose(z, np.arange(2))
  2266. def test_permutation(self):
  2267. A = block_diag(np.ones((2, 2)), np.tril(np.ones((2, 2))),
  2268. np.ones((3, 3)))
  2269. x, (y, z) = matrix_balance(A, separate=1)
  2270. assert_allclose(y, np.ones_like(y))
  2271. assert_allclose(z, np.array([0, 1, 6, 5, 4, 3, 2]))
  2272. def test_perm_and_scaling(self):
  2273. # Matrix with its diagonal removed
  2274. cases = ( # Case 0
  2275. np.array([[0., 0., 0., 0., 0.000002],
  2276. [0., 0., 0., 0., 0.],
  2277. [2., 2., 0., 0., 0.],
  2278. [2., 2., 0., 0., 0.],
  2279. [0., 0., 0.000002, 0., 0.]]),
  2280. # Case 1 user reported GH-7258
  2281. np.array([[-0.5, 0., 0., 0.],
  2282. [0., -1., 0., 0.],
  2283. [1., 0., -0.5, 0.],
  2284. [0., 1., 0., -1.]]),
  2285. # Case 2 user reported GH-7258
  2286. np.array([[-3., 0., 1., 0.],
  2287. [-1., -1., -0., 1.],
  2288. [-3., -0., -0., 0.],
  2289. [-1., -0., 1., -1.]])
  2290. )
  2291. for A in cases:
  2292. x, y = matrix_balance(A)
  2293. x, (s, p) = matrix_balance(A, separate=1)
  2294. ip = np.empty_like(p)
  2295. ip[p] = np.arange(A.shape[0])
  2296. assert_allclose(y, np.diag(s)[ip, :])
  2297. assert_allclose(solve(y, A).dot(y), x)
  2298. @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  2299. def test_empty(self, dt):
  2300. a = np.empty((0, 0), dtype=dt)
  2301. b, t = matrix_balance(a)
  2302. assert b.size == 0
  2303. assert t.size == 0
  2304. b_n, t_n = matrix_balance(np.eye(2, dtype=dt))
  2305. assert b.dtype == b_n.dtype
  2306. assert t.dtype == t_n.dtype
  2307. b, (scale, perm) = matrix_balance(a, separate=True)
  2308. assert b.size == 0
  2309. assert scale.size == 0
  2310. assert perm.size == 0
  2311. b_n, (scale_n, perm_n) = matrix_balance(a, separate=True)
  2312. assert b.dtype == b_n.dtype
  2313. assert scale.dtype == scale_n.dtype
  2314. assert perm.dtype == perm_n.dtype
  2315. class TestDTypes:
  2316. """Check backwards compatibility for dtypes vs scipy 1.16."""
  2317. def get_arr2D(self, tcode):
  2318. # return a valid 2D array for the typecode
  2319. if tcode == 'M':
  2320. return np.eye(2, dtype='datetime64[ms]')
  2321. elif tcode == 'V':
  2322. return np.asarray([[b'a', b'b'], [b'c', b'd']], dtype='V')
  2323. else:
  2324. return np.eye(2, dtype=tcode)
  2325. def get_arr1D(self, tcode):
  2326. # return a valid 1D array for the typecode
  2327. if tcode == 'M':
  2328. return np.ones(2, dtype='datetime64[ms]')
  2329. elif tcode == 'V':
  2330. return np.asarray([b'a', b'b'], dtype='V')
  2331. else:
  2332. return np.ones(2, dtype=tcode)
  2333. @pytest.mark.parametrize("tcode", np.typecodes['All'])
  2334. def test_inv(self, tcode):
  2335. # check backwards compat vs scipy 1.16
  2336. a = self.get_arr2D(tcode)
  2337. if tcode in 'SUVO':
  2338. # raises
  2339. with pytest.raises(ValueError):
  2340. inv(a)
  2341. else:
  2342. # passes
  2343. inv(a)
  2344. @pytest.mark.parametrize("tcode", np.typecodes['All'])
  2345. def test_det(self, tcode):
  2346. a = self.get_arr2D(tcode)
  2347. is_arm = platform.machine() == 'arm64'
  2348. is_windows = os.name == 'nt'
  2349. failing_tcodes = 'SUVOmM'
  2350. if not (is_arm or is_windows):
  2351. failing_tcodes += 'gG'
  2352. if tcode in failing_tcodes:
  2353. # raises
  2354. with pytest.raises(TypeError):
  2355. det(a)
  2356. else:
  2357. # passes
  2358. det(a)
  2359. @pytest.mark.filterwarnings("ignore:Casting complex values")
  2360. @pytest.mark.parametrize("tcode_a", np.typecodes['All'])
  2361. @pytest.mark.parametrize("tcode_b", np.typecodes['All'])
  2362. def test_solve(self, tcode_a, tcode_b):
  2363. a = self.get_arr2D(tcode_a)
  2364. b = self.get_arr1D(tcode_b)
  2365. can_combine = True
  2366. try:
  2367. np.result_type(tcode_a, tcode_b)
  2368. except TypeError:
  2369. can_combine = False
  2370. if not can_combine:
  2371. # np.exceptions.DTypePromotionError subclasses TypeError
  2372. with pytest.raises(TypeError):
  2373. solve(a, b)
  2374. elif tcode_a in 'SUVO' or tcode_b in 'VO':
  2375. with pytest.raises(ValueError):
  2376. solve(a, b)
  2377. else:
  2378. solve(a, b)