_basic.py 86 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424
  1. #
  2. # Author: Pearu Peterson, March 2002
  3. #
  4. # w/ additions by Travis Oliphant, March 2002
  5. # and Jake Vanderplas, August 2012
  6. import warnings
  7. from warnings import warn
  8. from itertools import product
  9. import numpy as np
  10. from numpy import atleast_1d, atleast_2d
  11. from scipy._lib._util import _apply_over_batch
  12. from .lapack import (
  13. get_lapack_funcs, _normalize_lapack_dtype, _normalize_lapack_dtype1,
  14. _compute_lwork
  15. )
  16. from ._misc import LinAlgError, _datacopied, LinAlgWarning
  17. from ._decomp import _asarray_validated
  18. from . import _decomp, _decomp_svd
  19. from ._solve_toeplitz import levinson
  20. from ._cythonized_array_utils import (find_det_from_lu, bandwidth, issymmetric,
  21. ishermitian)
  22. from . import _batched_linalg
  23. __all__ = ['solve', 'solve_triangular', 'solveh_banded', 'solve_banded',
  24. 'solve_toeplitz', 'solve_circulant', 'inv', 'det', 'lstsq',
  25. 'pinv', 'pinvh', 'matrix_balance', 'matmul_toeplitz']
  26. # Linear equations
  27. def _solve_check(n, info, lamch=None, rcond=None):
  28. """ Check arguments during the different steps of the solution phase """
  29. if info < 0:
  30. raise ValueError(f'LAPACK reported an illegal value in {-info}-th argument.')
  31. elif 0 < info or rcond == 0:
  32. raise LinAlgError('Matrix is singular.')
  33. if lamch is None:
  34. return
  35. E = lamch('E')
  36. if not (rcond >= E): # `rcond < E` doesn't handle NaN
  37. warn(f'Ill-conditioned matrix (rcond={rcond:.6g}): '
  38. 'result may not be accurate.',
  39. LinAlgWarning, stacklevel=3)
  40. def _find_matrix_structure(a):
  41. n = a.shape[0]
  42. n_below, n_above = bandwidth(a)
  43. if n_below == n_above == 0:
  44. kind = 'diagonal'
  45. elif n_above == 0:
  46. kind = 'lower triangular'
  47. elif n_below == 0:
  48. kind = 'upper triangular'
  49. elif n_above <= 1 and n_below <= 1 and n > 3:
  50. kind = 'tridiagonal'
  51. elif np.issubdtype(a.dtype, np.complexfloating) and ishermitian(a):
  52. kind = 'hermitian'
  53. elif issymmetric(a):
  54. kind = 'symmetric'
  55. else:
  56. kind = 'general'
  57. return kind, n_below, n_above
  58. def _format_emit_errors_warnings(err_lst):
  59. """Format/emit errors/warnings from a lowlevel batched routine.
  60. See inv, solve.
  61. """
  62. singular, lapack_err, ill_cond = [], [], []
  63. for i, dct in enumerate(err_lst):
  64. if dct["is_singular"]:
  65. singular.append(i)
  66. if dct["lapack_info"] < 0:
  67. lapack_err.append(f"slice {i} emits lapack info={dct['lapack_info']}")
  68. if dct["is_ill_conditioned"]:
  69. ill_cond.append(f"slice {i} has rcond = {dct['rcond']}")
  70. if singular:
  71. raise LinAlgError(
  72. f"A singular matrix detected: slice(s) {singular} are singular."
  73. )
  74. if lapack_err:
  75. raise ValueError(f"Internal LAPACK errors: {','.join(lapack_err)}.")
  76. if ill_cond:
  77. warnings.warn(
  78. f"An ill-conditioned matrix detected: {','.join(ill_cond)}.",
  79. LinAlgWarning,
  80. stacklevel=3
  81. )
  82. def solve(a, b, lower=False, overwrite_a=False,
  83. overwrite_b=False, check_finite=True, assume_a=None,
  84. transposed=False):
  85. """
  86. Solve the equation ``a @ x = b`` for ``x``,
  87. where `a` is a square matrix.
  88. If the data matrix is known to be a particular type then supplying the
  89. corresponding string to ``assume_a`` key chooses the dedicated solver.
  90. The available options are
  91. ============================= ================================
  92. diagonal 'diagonal'
  93. tridiagonal 'tridiagonal'
  94. banded 'banded'
  95. upper triangular 'upper triangular'
  96. lower triangular 'lower triangular'
  97. symmetric 'symmetric' (or 'sym')
  98. hermitian 'hermitian' (or 'her')
  99. symmetric positive definite 'positive definite' (or 'pos')
  100. general 'general' (or 'gen')
  101. ============================= ================================
  102. Array argument(s) of this function may have additional
  103. "batch" dimensions prepended to the core shape. In this case, the array is treated
  104. as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.
  105. Parameters
  106. ----------
  107. a : array_like, shape (..., N, N)
  108. Square left-hand side matrix or a batch of matrices.
  109. b : (..., N, NRHS) array_like
  110. Input data for the right hand side or a batch of right-hand sides.
  111. lower : bool, default: False
  112. Ignored unless ``assume_a`` is one of ``'sym'``, ``'her'``, or ``'pos'``.
  113. If True, the calculation uses only the data in the lower triangle of `a`;
  114. entries above the diagonal are ignored. If False (default), the
  115. calculation uses only the data in the upper triangle of `a`; entries
  116. below the diagonal are ignored.
  117. overwrite_a : bool, default: False
  118. Allow overwriting data in `a` (may enhance performance).
  119. overwrite_b : bool, default: False
  120. Allow overwriting data in `b` (may enhance performance).
  121. check_finite : bool, default: True
  122. Whether to check that the input matrices contain only finite numbers.
  123. Disabling may give a performance gain, but may result in problems
  124. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  125. assume_a : str, optional
  126. Valid entries are described above.
  127. If omitted or ``None``, checks are performed to identify structure so the
  128. appropriate solver can be called.
  129. transposed : bool, default: False
  130. If True, solve ``a.T @ x == b``. Raises `NotImplementedError`
  131. for complex `a`.
  132. Returns
  133. -------
  134. x : ndarray, shape (N, NRHS) or (..., N)
  135. The solution array.
  136. Raises
  137. ------
  138. ValueError
  139. If size mismatches detected or input a is not square.
  140. LinAlgError
  141. If the computation fails because of matrix singularity.
  142. LinAlgWarning
  143. If an ill-conditioned input a is detected.
  144. NotImplementedError
  145. If transposed is True and input a is a complex matrix.
  146. Notes
  147. -----
  148. If the input b matrix is a 1-D array with N elements, when supplied
  149. together with an NxN input a, it is assumed as a valid column vector
  150. despite the apparent size mismatch. This is compatible with the
  151. numpy.dot() behavior and the returned result is still 1-D array.
  152. The general, symmetric, Hermitian and positive definite solutions are
  153. obtained via calling ?GETRF/?GETRS, ?SYSV, ?HESV, and ?POTRF/?POTRS routines of
  154. LAPACK respectively.
  155. The datatype of the arrays define which solver is called regardless
  156. of the values. In other words, even when the complex array entries have
  157. precisely zero imaginary parts, the complex solver will be called based
  158. on the data type of the array.
  159. Examples
  160. --------
  161. Given `a` and `b`, solve for `x`:
  162. >>> import numpy as np
  163. >>> a = np.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]])
  164. >>> b = np.array([2, 4, -1])
  165. >>> from scipy.linalg import solve
  166. >>> x = solve(a, b)
  167. >>> x
  168. array([ 2., -2., 9.])
  169. >>> a @ x == b
  170. array([ True, True, True], dtype=bool)
  171. Batches of matrices are supported, with and without structure detection:
  172. >>> a = np.arange(12).reshape(3, 2, 2) # a batch of 3 2x2 matrices
  173. >>> A = a.transpose(0, 2, 1) @ a # A is a batch of 3 positive definite matrices
  174. >>> b = np.ones(2)
  175. >>> solve(A, b) # this automatically detects that A is pos.def.
  176. array([[ 1. , -0.5],
  177. [ 3. , -2.5],
  178. [ 5. , -4.5]])
  179. >>> solve(A, b, assume_a='pos') # bypass structucture detection
  180. array([[ 1. , -0.5],
  181. [ 3. , -2.5],
  182. [ 5. , -4.5]])
  183. """
  184. if assume_a in ['banded']:
  185. # TODO: handle these structures in this function
  186. return solve0(
  187. a, b, lower=lower, overwrite_a=overwrite_a, overwrite_b=overwrite_b,
  188. check_finite=check_finite, assume_a=assume_a, transposed=transposed
  189. )
  190. # keep the numbers in sync with C
  191. structure = {
  192. None: -1,
  193. 'general': 0, 'gen': 0,
  194. 'diagonal': 11,
  195. 'tridiagonal': 31,
  196. 'upper triangular': 21,
  197. 'lower triangular': 22,
  198. 'pos' : 101, 'positive definite': 101,
  199. 'sym' : 201, 'symmetric': 201,
  200. 'her' : 211, 'hermitian': 211,
  201. }.get(assume_a, 'unknown')
  202. if structure == 'unknown':
  203. raise ValueError(f'{assume_a} is not a recognized matrix structure')
  204. a1 = np.atleast_2d(_asarray_validated(a, check_finite=check_finite))
  205. b1 = np.atleast_1d(_asarray_validated(b, check_finite=check_finite))
  206. a1, b1 = _ensure_dtype_cdsz(a1, b1) # XXX; b upcasts a?
  207. a1, overwrite_a = _normalize_lapack_dtype(a1, overwrite_a)
  208. if a1.ndim < 2:
  209. raise ValueError(f"Expected at least ndim=2, got {a1.ndim=}")
  210. if a1.shape[-1] != a1.shape[-2]:
  211. raise ValueError(f"Expected square matrix, got {a1.shape=}")
  212. # backwards compatibility
  213. if np.issubdtype(a1.dtype, np.complexfloating) and transposed:
  214. raise NotImplementedError('scipy.linalg.solve can currently '
  215. 'not solve a^T x = b or a^H x = b '
  216. 'for complex matrices.')
  217. if not (a1.flags['ALIGNED'] or a1.dtype.byteorder == '='):
  218. overwrite_a = True
  219. a1 = a1.copy()
  220. if not (b1.flags['ALIGNED'] or b1.dtype.byteorder == '='):
  221. overwrite_a = True
  222. b1 = b1.copy()
  223. # align the shape of b with a: 1. make b1 at least 2D
  224. b_is_1D = b1.ndim == 1
  225. if b_is_1D:
  226. b1 = b1[:, None]
  227. a_is_scalar = a1.size == 1
  228. if b1.shape[-2] != a1.shape[-1] and not a_is_scalar:
  229. raise ValueError(f"incompatible shapes: {a1.shape=} and {b1.shape=}")
  230. # 2. broadcast the batch dimensions of b1 and a1
  231. batch_shape = np.broadcast_shapes(a1.shape[:-2], b1.shape[:-2])
  232. a1 = np.broadcast_to(a1, batch_shape + a1.shape[-2:])
  233. b1 = np.broadcast_to(b1, batch_shape + b1.shape[-2:])
  234. # catch empty inputs
  235. if a1.size == 0 or b1.size == 0:
  236. x = np.empty_like(b1)
  237. if b_is_1D:
  238. x = x[..., 0]
  239. return x
  240. if a_is_scalar:
  241. if a1.item() == 0:
  242. raise LinAlgError("A singular matrix detected.")
  243. out = b1 / a1
  244. return out[..., 0] if b_is_1D else out
  245. # XXX a1.ndim > 2 ; b1.ndim > 2
  246. # XXX can do something if a1 C ordered & transposed==True ?
  247. overwrite_a = overwrite_a and (a1.ndim == 2) and (a1.flags["F_CONTIGUOUS"])
  248. overwrite_b = overwrite_b and (b1.ndim <= 2) and (b1.flags["F_CONTIGUOUS"])
  249. # heavy lifting
  250. x, err_lst = _batched_linalg._solve(
  251. a1, b1, structure, lower, transposed, overwrite_a, overwrite_b
  252. )
  253. if err_lst:
  254. _format_emit_errors_warnings(err_lst)
  255. if b_is_1D:
  256. x = x[..., 0]
  257. return x
  258. @_apply_over_batch(('a', 2), ('b', '1|2'))
  259. def solve0(a, b, lower=False, overwrite_a=False,
  260. overwrite_b=False, check_finite=True, assume_a=None,
  261. transposed=False):
  262. """
  263. Solve the equation ``a @ x = b`` for ``x``,
  264. where `a` is a square matrix.
  265. If the data matrix is known to be a particular type then supplying the
  266. corresponding string to ``assume_a`` key chooses the dedicated solver.
  267. The available options are
  268. ============================= ================================
  269. diagonal 'diagonal'
  270. tridiagonal 'tridiagonal'
  271. banded 'banded'
  272. upper triangular 'upper triangular'
  273. lower triangular 'lower triangular'
  274. symmetric 'symmetric' (or 'sym')
  275. hermitian 'hermitian' (or 'her')
  276. symmetric positive definite 'positive definite' (or 'pos')
  277. general 'general' (or 'gen')
  278. ============================= ================================
  279. Parameters
  280. ----------
  281. a : (N, N) array_like
  282. Square input data
  283. b : (N, NRHS) array_like
  284. Input data for the right hand side.
  285. lower : bool, default: False
  286. Ignored unless ``assume_a`` is one of ``'sym'``, ``'her'``, or ``'pos'``.
  287. If True, the calculation uses only the data in the lower triangle of `a`;
  288. entries above the diagonal are ignored. If False (default), the
  289. calculation uses only the data in the upper triangle of `a`; entries
  290. below the diagonal are ignored.
  291. overwrite_a : bool, default: False
  292. Allow overwriting data in `a` (may enhance performance).
  293. overwrite_b : bool, default: False
  294. Allow overwriting data in `b` (may enhance performance).
  295. check_finite : bool, default: True
  296. Whether to check that the input matrices contain only finite numbers.
  297. Disabling may give a performance gain, but may result in problems
  298. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  299. assume_a : str, optional
  300. Valid entries are described above.
  301. If omitted or ``None``, checks are performed to identify structure so the
  302. appropriate solver can be called.
  303. transposed : bool, default: False
  304. If True, solve ``a.T @ x == b``. Raises `NotImplementedError`
  305. for complex `a`.
  306. Returns
  307. -------
  308. x : (N, NRHS) ndarray
  309. The solution array.
  310. Raises
  311. ------
  312. ValueError
  313. If size mismatches detected or input a is not square.
  314. LinAlgError
  315. If the computation fails because of matrix singularity.
  316. LinAlgWarning
  317. If an ill-conditioned input a is detected.
  318. NotImplementedError
  319. If transposed is True and input a is a complex matrix.
  320. Notes
  321. -----
  322. If the input b matrix is a 1-D array with N elements, when supplied
  323. together with an NxN input a, it is assumed as a valid column vector
  324. despite the apparent size mismatch. This is compatible with the
  325. numpy.dot() behavior and the returned result is still 1-D array.
  326. The general, symmetric, Hermitian and positive definite solutions are
  327. obtained via calling ?GESV, ?SYSV, ?HESV, and ?POSV routines of
  328. LAPACK respectively.
  329. The datatype of the arrays define which solver is called regardless
  330. of the values. In other words, even when the complex array entries have
  331. precisely zero imaginary parts, the complex solver will be called based
  332. on the data type of the array.
  333. Examples
  334. --------
  335. Given `a` and `b`, solve for `x`:
  336. >>> import numpy as np
  337. >>> a = np.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]])
  338. >>> b = np.array([2, 4, -1])
  339. >>> from scipy import linalg
  340. >>> x = linalg.solve(a, b)
  341. >>> x
  342. array([ 2., -2., 9.])
  343. >>> np.dot(a, x) == b
  344. array([ True, True, True], dtype=bool)
  345. """
  346. # Flags for 1-D or N-D right-hand side
  347. b_is_1D = False
  348. # check finite after determining structure
  349. a1 = atleast_2d(_asarray_validated(a, check_finite=False))
  350. b1 = atleast_1d(_asarray_validated(b, check_finite=False))
  351. a1, b1 = _ensure_dtype_cdsz(a1, b1)
  352. n = a1.shape[0]
  353. overwrite_a = overwrite_a or _datacopied(a1, a)
  354. overwrite_b = overwrite_b or _datacopied(b1, b)
  355. if a1.shape[0] != a1.shape[1]:
  356. raise ValueError('Input a needs to be a square matrix.')
  357. if n != b1.shape[0]:
  358. # Last chance to catch 1x1 scalar a and 1-D b arrays
  359. if not (n == 1 and b1.size != 0):
  360. raise ValueError('Input b has to have same number of rows as '
  361. 'input a')
  362. # accommodate empty arrays
  363. if b1.size == 0:
  364. dt = solve(np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype)).dtype
  365. return np.empty_like(b1, dtype=dt)
  366. # regularize 1-D b arrays to 2D
  367. if b1.ndim == 1:
  368. if n == 1:
  369. b1 = b1[None, :]
  370. else:
  371. b1 = b1[:, None]
  372. b_is_1D = True
  373. if assume_a not in {None, 'diagonal', 'tridiagonal', 'banded', 'lower triangular',
  374. 'upper triangular', 'symmetric', 'hermitian',
  375. 'positive definite', 'general', 'sym', 'her', 'pos', 'gen'}:
  376. raise ValueError(f'{assume_a} is not a recognized matrix structure')
  377. # for a real matrix, describe it as "symmetric", not "hermitian"
  378. # (lapack doesn't know what to do with real hermitian matrices)
  379. if assume_a in {'hermitian', 'her'} and not np.iscomplexobj(a1):
  380. assume_a = 'symmetric'
  381. n_below, n_above = None, None
  382. if assume_a is None:
  383. assume_a, n_below, n_above = _find_matrix_structure(a1)
  384. # Get the correct lamch function.
  385. # The LAMCH functions only exists for S and D
  386. # So for complex values we have to convert to real/double.
  387. if a1.dtype.char in 'fF': # single precision
  388. lamch = get_lapack_funcs('lamch', dtype='f')
  389. else:
  390. lamch = get_lapack_funcs('lamch', dtype='d')
  391. # Since the I-norm and 1-norm are the same for symmetric matrices
  392. # we can collect them all in this one call
  393. # Note however, that when issuing 'gen' and form!='none', then
  394. # the I-norm should be used
  395. if transposed:
  396. trans = 1
  397. norm = 'I'
  398. if np.iscomplexobj(a1):
  399. raise NotImplementedError('scipy.linalg.solve can currently '
  400. 'not solve a^T x = b or a^H x = b '
  401. 'for complex matrices.')
  402. else:
  403. trans = 0
  404. norm = '1'
  405. # Currently we do not have the other forms of the norm calculators
  406. # lansy, lanpo, lanhe.
  407. # However, in any case they only reduce computations slightly...
  408. if assume_a == 'diagonal':
  409. anorm = _matrix_norm_diagonal(a1, check_finite)
  410. elif assume_a == 'tridiagonal':
  411. anorm = _matrix_norm_tridiagonal(norm, a1, check_finite)
  412. elif assume_a == 'banded':
  413. n_below, n_above = bandwidth(a1) if n_below is None else (n_below, n_above)
  414. a2, n_below, n_above = ((a1.T, n_above, n_below) if transposed
  415. else (a1, n_below, n_above))
  416. ab = _to_banded(n_below, n_above, a2)
  417. anorm = _matrix_norm_banded(n_below, n_above, norm, ab, check_finite)
  418. elif assume_a in {'lower triangular', 'upper triangular'}:
  419. anorm = _matrix_norm_triangular(assume_a, norm, a1, check_finite)
  420. else:
  421. anorm = _matrix_norm_general(norm, a1, check_finite)
  422. info, rcond = 0, np.inf
  423. # Generalized case 'gesv'
  424. if assume_a in {'general', 'gen'}:
  425. gecon, getrf, getrs = get_lapack_funcs(('gecon', 'getrf', 'getrs'),
  426. (a1, b1))
  427. lu, ipvt, info = getrf(a1, overwrite_a=overwrite_a)
  428. _solve_check(n, info)
  429. x, info = getrs(lu, ipvt, b1,
  430. trans=trans, overwrite_b=overwrite_b)
  431. _solve_check(n, info)
  432. rcond, info = gecon(lu, anorm, norm=norm)
  433. # Hermitian case 'hesv'
  434. elif assume_a in {'hermitian', 'her'}:
  435. hecon, hesv, hesv_lw = get_lapack_funcs(('hecon', 'hesv',
  436. 'hesv_lwork'), (a1, b1))
  437. lwork = _compute_lwork(hesv_lw, n, lower)
  438. lu, ipvt, x, info = hesv(a1, b1, lwork=lwork,
  439. lower=lower,
  440. overwrite_a=overwrite_a,
  441. overwrite_b=overwrite_b)
  442. _solve_check(n, info)
  443. rcond, info = hecon(lu, ipvt, anorm, lower=lower)
  444. # Symmetric case 'sysv'
  445. elif assume_a in {'symmetric', 'sym'}:
  446. sycon, sysv, sysv_lw = get_lapack_funcs(('sycon', 'sysv',
  447. 'sysv_lwork'), (a1, b1))
  448. lwork = _compute_lwork(sysv_lw, n, lower)
  449. lu, ipvt, x, info = sysv(a1, b1, lwork=lwork,
  450. lower=lower,
  451. overwrite_a=overwrite_a,
  452. overwrite_b=overwrite_b)
  453. _solve_check(n, info)
  454. rcond, info = sycon(lu, ipvt, anorm, lower=lower)
  455. # Diagonal case
  456. elif assume_a == 'diagonal':
  457. diag_a = np.diag(a1)
  458. x = (b1.T / diag_a).T
  459. abs_diag_a = np.abs(diag_a)
  460. diag_min = abs_diag_a.min()
  461. rcond = diag_min if diag_min == 0 else diag_min / abs_diag_a.max()
  462. # Tri-diagonal case
  463. elif assume_a == 'tridiagonal':
  464. a1 = a1.T if transposed else a1
  465. dl, d, du = np.diag(a1, -1), np.diag(a1, 0), np.diag(a1, 1)
  466. _gttrf, _gttrs, _gtcon = get_lapack_funcs(('gttrf', 'gttrs', 'gtcon'), (a1, b1))
  467. dl, d, du, du2, ipiv, info = _gttrf(dl, d, du)
  468. _solve_check(n, info)
  469. x, info = _gttrs(dl, d, du, du2, ipiv, b1, overwrite_b=overwrite_b)
  470. _solve_check(n, info)
  471. rcond, info = _gtcon(dl, d, du, du2, ipiv, anorm)
  472. # Banded case
  473. elif assume_a == 'banded':
  474. gbsv, gbcon = get_lapack_funcs(('gbsv', 'gbcon'), (a1, b1))
  475. # Next two lines copied from `solve_banded`
  476. a2 = np.zeros((2*n_below + n_above + 1, ab.shape[1]), dtype=gbsv.dtype)
  477. a2[n_below:, :] = ab
  478. lu, piv, x, info = gbsv(n_below, n_above, a2, b1,
  479. overwrite_ab=True, overwrite_b=overwrite_b)
  480. _solve_check(n, info)
  481. rcond, info = gbcon(n_below, n_above, lu, piv, anorm)
  482. # Triangular case
  483. elif assume_a in {'lower triangular', 'upper triangular'}:
  484. lower = assume_a == 'lower triangular'
  485. x, info = _solve_triangular(a1, b1, lower=lower, overwrite_b=overwrite_b,
  486. trans=transposed)
  487. _solve_check(n, info)
  488. _trcon = get_lapack_funcs(('trcon'), (a1, b1))
  489. rcond, info = _trcon(a1, uplo='L' if lower else 'U')
  490. # Positive definite case 'posv'
  491. else:
  492. pocon, posv = get_lapack_funcs(('pocon', 'posv'),
  493. (a1, b1))
  494. lu, x, info = posv(a1, b1, lower=lower,
  495. overwrite_a=overwrite_a,
  496. overwrite_b=overwrite_b)
  497. _solve_check(n, info)
  498. rcond, info = pocon(lu, anorm)
  499. _solve_check(n, info, lamch, rcond)
  500. if b_is_1D:
  501. x = x.ravel()
  502. return x
  503. def _matrix_norm_diagonal(a, check_finite):
  504. # Equivalent of dlange for diagonal matrix, assuming
  505. # norm is either 'I' or '1' (really just not the Frobenius norm)
  506. d = np.diag(a)
  507. d = np.asarray_chkfinite(d) if check_finite else d
  508. return np.abs(d).max()
  509. def _matrix_norm_tridiagonal(norm, a, check_finite):
  510. # Equivalent of dlange for tridiagonal matrix, assuming
  511. # norm is either 'I' or '1'
  512. if norm == 'I':
  513. a = a.T
  514. # Context to avoid warning before error in cases like -inf + inf
  515. with np.errstate(invalid='ignore'):
  516. d = np.abs(np.diag(a))
  517. d[1:] += np.abs(np.diag(a, 1))
  518. d[:-1] += np.abs(np.diag(a, -1))
  519. d = np.asarray_chkfinite(d) if check_finite else d
  520. return d.max()
  521. def _matrix_norm_triangular(structure, norm, a, check_finite):
  522. a = np.asarray_chkfinite(a) if check_finite else a
  523. lantr = get_lapack_funcs('lantr', (a,))
  524. return lantr(norm, a, 'L' if structure == 'lower triangular' else 'U' )
  525. def _matrix_norm_banded(kl, ku, norm, ab, check_finite):
  526. ab = np.asarray_chkfinite(ab) if check_finite else ab
  527. langb = get_lapack_funcs('langb', (ab,))
  528. return langb(norm, kl, ku, ab)
  529. def _matrix_norm_general(norm, a, check_finite):
  530. a = np.asarray_chkfinite(a) if check_finite else a
  531. lange = get_lapack_funcs('lange', (a,))
  532. return lange(norm, a)
  533. def _to_banded(n_below, n_above, a):
  534. n = a.shape[0]
  535. rows = n_above + n_below + 1
  536. ab = np.zeros((rows, n), dtype=a.dtype)
  537. ab[n_above] = np.diag(a)
  538. for i in range(1, n_above + 1):
  539. ab[n_above - i, i:] = np.diag(a, i)
  540. for i in range(1, n_below + 1):
  541. ab[n_above + i, :-i] = np.diag(a, -i)
  542. return ab
  543. def _ensure_dtype_cdsz(*arrays):
  544. # Ensure that the dtype of arrays is one of the standard types
  545. # compatible with LAPACK functions (single or double precision
  546. # real or complex).
  547. dtype = np.result_type(*arrays)
  548. if not np.issubdtype(dtype, np.inexact):
  549. return (array.astype(np.float64) for array in arrays)
  550. complex = np.issubdtype(dtype, np.complexfloating)
  551. if np.finfo(dtype).bits <= 32:
  552. dtype = np.complex64 if complex else np.float32
  553. elif np.finfo(dtype).bits >= 64:
  554. dtype = np.complex128 if complex else np.float64
  555. return (array.astype(dtype, copy=False) for array in arrays)
  556. @_apply_over_batch(('a', 2), ('b', '1|2'))
  557. def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
  558. overwrite_b=False, check_finite=True):
  559. """
  560. Solve the equation ``a @ x = b`` for ``x``, where `a` is a triangular matrix.
  561. Parameters
  562. ----------
  563. a : (M, M) array_like
  564. A triangular matrix
  565. b : (M,) or (M, N) array_like
  566. Right-hand side matrix in ``a x = b``
  567. lower : bool, optional
  568. Use only data contained in the lower triangle of `a`.
  569. Default is to use upper triangle.
  570. trans : {0, 1, 2, 'N', 'T', 'C'}, optional
  571. Type of system to solve:
  572. ======== =========
  573. trans system
  574. ======== =========
  575. 0 or 'N' a x = b
  576. 1 or 'T' a^T x = b
  577. 2 or 'C' a^H x = b
  578. ======== =========
  579. unit_diagonal : bool, optional
  580. If True, diagonal elements of `a` are assumed to be 1 and
  581. will not be referenced.
  582. overwrite_b : bool, optional
  583. Allow overwriting data in `b` (may enhance performance)
  584. check_finite : bool, optional
  585. Whether to check that the input matrices contain only finite numbers.
  586. Disabling may give a performance gain, but may result in problems
  587. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  588. Returns
  589. -------
  590. x : (M,) or (M, N) ndarray
  591. Solution to the system ``a x = b``. Shape of return matches `b`.
  592. Raises
  593. ------
  594. LinAlgError
  595. If `a` is singular
  596. Notes
  597. -----
  598. .. versionadded:: 0.9.0
  599. Examples
  600. --------
  601. Solve the lower triangular system a x = b, where::
  602. [3 0 0 0] [4]
  603. a = [2 1 0 0] b = [2]
  604. [1 0 1 0] [4]
  605. [1 1 1 1] [2]
  606. >>> import numpy as np
  607. >>> from scipy.linalg import solve_triangular
  608. >>> a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]])
  609. >>> b = np.array([4, 2, 4, 2])
  610. >>> x = solve_triangular(a, b, lower=True)
  611. >>> x
  612. array([ 1.33333333, -0.66666667, 2.66666667, -1.33333333])
  613. >>> a.dot(x) # Check the result
  614. array([ 4., 2., 4., 2.])
  615. """
  616. a1 = _asarray_validated(a, check_finite=check_finite)
  617. b1 = _asarray_validated(b, check_finite=check_finite)
  618. if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
  619. raise ValueError('expected square matrix')
  620. if a1.shape[0] != b1.shape[0]:
  621. raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible')
  622. # accommodate empty arrays
  623. if b1.size == 0:
  624. dt_nonempty = solve_triangular(
  625. np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype)
  626. ).dtype
  627. return np.empty_like(b1, dtype=dt_nonempty)
  628. overwrite_b = overwrite_b or _datacopied(b1, b)
  629. x, _ = _solve_triangular(a1, b1, trans, lower, unit_diagonal, overwrite_b)
  630. return x
  631. # solve_triangular without the input validation
  632. def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False,
  633. overwrite_b=False):
  634. trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans)
  635. trtrs, = get_lapack_funcs(('trtrs',), (a1, b1))
  636. if a1.flags.f_contiguous or trans == 2:
  637. x, info = trtrs(a1, b1, overwrite_b=overwrite_b, lower=lower,
  638. trans=trans, unitdiag=unit_diagonal)
  639. else:
  640. # transposed system is solved since trtrs expects Fortran ordering
  641. x, info = trtrs(a1.T, b1, overwrite_b=overwrite_b, lower=not lower,
  642. trans=not trans, unitdiag=unit_diagonal)
  643. if info == 0:
  644. return x, info
  645. if info > 0:
  646. raise LinAlgError(f"singular matrix: resolution failed at diagonal {info-1}")
  647. raise ValueError(f'illegal value in {-info}-th argument of internal trtrs')
  648. def solve_banded(l_and_u, ab, b, overwrite_ab=False, overwrite_b=False,
  649. check_finite=True):
  650. """
  651. Solve the equation ``a @ x = b`` for ``x``, where ``a`` is the banded matrix
  652. defined by `ab`.
  653. The matrix a is stored in `ab` using the matrix diagonal ordered form::
  654. ab[u + i - j, j] == a[i,j]
  655. Example of `ab` (shape of a is (6,6), `u` =1, `l` =2)::
  656. * a01 a12 a23 a34 a45
  657. a00 a11 a22 a33 a44 a55
  658. a10 a21 a32 a43 a54 *
  659. a20 a31 a42 a53 * *
  660. The documentation is written assuming array arguments are of specified
  661. "core" shapes. However, array argument(s) of this function may have additional
  662. "batch" dimensions prepended to the core shape. In this case, the array is treated
  663. as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.
  664. Parameters
  665. ----------
  666. (l, u) : (integer, integer)
  667. Number of non-zero lower and upper diagonals
  668. ab : (`l` + `u` + 1, M) array_like
  669. Banded matrix
  670. b : (M,) or (M, K) array_like
  671. Right-hand side
  672. overwrite_ab : bool, optional
  673. Discard data in `ab` (may enhance performance)
  674. overwrite_b : bool, optional
  675. Discard data in `b` (may enhance performance)
  676. check_finite : bool, optional
  677. Whether to check that the input matrices contain only finite numbers.
  678. Disabling may give a performance gain, but may result in problems
  679. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  680. Returns
  681. -------
  682. x : (M,) or (M, K) ndarray
  683. The solution to the system a x = b. Returned shape depends on the
  684. shape of `b`.
  685. Examples
  686. --------
  687. Solve the banded system a x = b, where::
  688. [5 2 -1 0 0] [0]
  689. [1 4 2 -1 0] [1]
  690. a = [0 1 3 2 -1] b = [2]
  691. [0 0 1 2 2] [2]
  692. [0 0 0 1 1] [3]
  693. There is one nonzero diagonal below the main diagonal (l = 1), and
  694. two above (u = 2). The diagonal banded form of the matrix is::
  695. [* * -1 -1 -1]
  696. ab = [* 2 2 2 2]
  697. [5 4 3 2 1]
  698. [1 1 1 1 *]
  699. >>> import numpy as np
  700. >>> from scipy.linalg import solve_banded
  701. >>> ab = np.array([[0, 0, -1, -1, -1],
  702. ... [0, 2, 2, 2, 2],
  703. ... [5, 4, 3, 2, 1],
  704. ... [1, 1, 1, 1, 0]])
  705. >>> b = np.array([0, 1, 2, 2, 3])
  706. >>> x = solve_banded((1, 2), ab, b)
  707. >>> x
  708. array([-2.37288136, 3.93220339, -4. , 4.3559322 , -1.3559322 ])
  709. """
  710. (nlower, nupper) = l_and_u
  711. return _solve_banded(nlower, nupper, ab, b, overwrite_ab=overwrite_ab,
  712. overwrite_b=overwrite_b, check_finite=check_finite)
  713. @_apply_over_batch(('nlower', 0), ('nupper', 0), ('ab', 2), ('b', '1|2'))
  714. def _solve_banded(nlower, nupper, ab, b, overwrite_ab, overwrite_b, check_finite):
  715. a1 = _asarray_validated(ab, check_finite=check_finite, as_inexact=True)
  716. b1 = _asarray_validated(b, check_finite=check_finite, as_inexact=True)
  717. # Validate shapes.
  718. if a1.shape[-1] != b1.shape[0]:
  719. raise ValueError("shapes of ab and b are not compatible.")
  720. if nlower + nupper + 1 != a1.shape[0]:
  721. raise ValueError(
  722. f"invalid values for the number of lower and upper diagonals: l+u+1 "
  723. f"({nlower + nupper + 1}) does not equal ab.shape[0] ({ab.shape[0]})"
  724. )
  725. # accommodate empty arrays
  726. if b1.size == 0:
  727. dt = solve(np.eye(1, dtype=a1.dtype), np.ones(1, dtype=b1.dtype)).dtype
  728. return np.empty_like(b1, dtype=dt)
  729. overwrite_b = overwrite_b or _datacopied(b1, b)
  730. if a1.shape[-1] == 1:
  731. b2 = np.array(b1, copy=(not overwrite_b))
  732. # a1.shape[-1] == 1 -> original matrix is 1x1. Typically, the user
  733. # will pass u = l = 0 and `a1` will be 1x1. However, the rest of the
  734. # function works with unnecessary rows in `a1` as long as
  735. # `a1[u + i - j, j] == a[i,j]`. In the 1x1 case, we want i = j = 0,
  736. # so the diagonal is in row `u` of `a1`. See gh-8906.
  737. b2 /= a1[nupper, 0]
  738. return b2
  739. if nlower == nupper == 1:
  740. overwrite_ab = overwrite_ab or _datacopied(a1, ab)
  741. gtsv, = get_lapack_funcs(('gtsv',), (a1, b1))
  742. du = a1[0, 1:]
  743. d = a1[1, :]
  744. dl = a1[2, :-1]
  745. du2, d, du, x, info = gtsv(dl, d, du, b1, overwrite_ab, overwrite_ab,
  746. overwrite_ab, overwrite_b)
  747. else:
  748. gbsv, = get_lapack_funcs(('gbsv',), (a1, b1))
  749. a2 = np.zeros((2*nlower + nupper + 1, a1.shape[1]), dtype=gbsv.dtype)
  750. a2[nlower:, :] = a1
  751. lu, piv, x, info = gbsv(nlower, nupper, a2, b1, overwrite_ab=True,
  752. overwrite_b=overwrite_b)
  753. if info == 0:
  754. return x
  755. if info > 0:
  756. raise LinAlgError("singular matrix")
  757. raise ValueError(f'illegal value in {-info}-th argument of internal gbsv/gtsv')
  758. @_apply_over_batch(('a', 2), ('b', '1|2'))
  759. def solveh_banded(ab, b, overwrite_ab=False, overwrite_b=False, lower=False,
  760. check_finite=True):
  761. """
  762. Solve the equation ``a @ x = b`` for ``x``, where ``a`` is the
  763. Hermitian positive-definite banded matrix defined by `ab`.
  764. Uses Thomas' Algorithm, which is more efficient than standard LU
  765. factorization, but should only be used for Hermitian positive-definite
  766. matrices.
  767. The matrix ``a`` is stored in `ab` either in lower diagonal or upper
  768. diagonal ordered form:
  769. ab[u + i - j, j] == a[i,j] (if upper form; i <= j)
  770. ab[ i - j, j] == a[i,j] (if lower form; i >= j)
  771. Example of `ab` (shape of ``a`` is (6, 6), number of upper diagonals,
  772. ``u`` =2)::
  773. upper form:
  774. * * a02 a13 a24 a35
  775. * a01 a12 a23 a34 a45
  776. a00 a11 a22 a33 a44 a55
  777. lower form:
  778. a00 a11 a22 a33 a44 a55
  779. a10 a21 a32 a43 a54 *
  780. a20 a31 a42 a53 * *
  781. Cells marked with * are not used.
  782. Parameters
  783. ----------
  784. ab : (``u`` + 1, M) array_like
  785. Banded matrix
  786. b : (M,) or (M, K) array_like
  787. Right-hand side
  788. overwrite_ab : bool, optional
  789. Discard data in `ab` (may enhance performance)
  790. overwrite_b : bool, optional
  791. Discard data in `b` (may enhance performance)
  792. lower : bool, optional
  793. Is the matrix in the lower form. (Default is upper form)
  794. check_finite : bool, optional
  795. Whether to check that the input matrices contain only finite numbers.
  796. Disabling may give a performance gain, but may result in problems
  797. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  798. Returns
  799. -------
  800. x : (M,) or (M, K) ndarray
  801. The solution to the system ``a x = b``. Shape of return matches shape
  802. of `b`.
  803. Notes
  804. -----
  805. In the case of a non-positive definite matrix ``a``, the solver
  806. `solve_banded` may be used.
  807. Examples
  808. --------
  809. Solve the banded system ``A x = b``, where::
  810. [ 4 2 -1 0 0 0] [1]
  811. [ 2 5 2 -1 0 0] [2]
  812. A = [-1 2 6 2 -1 0] b = [2]
  813. [ 0 -1 2 7 2 -1] [3]
  814. [ 0 0 -1 2 8 2] [3]
  815. [ 0 0 0 -1 2 9] [3]
  816. >>> import numpy as np
  817. >>> from scipy.linalg import solveh_banded
  818. ``ab`` contains the main diagonal and the nonzero diagonals below the
  819. main diagonal. That is, we use the lower form:
  820. >>> ab = np.array([[ 4, 5, 6, 7, 8, 9],
  821. ... [ 2, 2, 2, 2, 2, 0],
  822. ... [-1, -1, -1, -1, 0, 0]])
  823. >>> b = np.array([1, 2, 2, 3, 3, 3])
  824. >>> x = solveh_banded(ab, b, lower=True)
  825. >>> x
  826. array([ 0.03431373, 0.45938375, 0.05602241, 0.47759104, 0.17577031,
  827. 0.34733894])
  828. Solve the Hermitian banded system ``H x = b``, where::
  829. [ 8 2-1j 0 0 ] [ 1 ]
  830. H = [2+1j 5 1j 0 ] b = [1+1j]
  831. [ 0 -1j 9 -2-1j] [1-2j]
  832. [ 0 0 -2+1j 6 ] [ 0 ]
  833. In this example, we put the upper diagonals in the array ``hb``:
  834. >>> hb = np.array([[0, 2-1j, 1j, -2-1j],
  835. ... [8, 5, 9, 6 ]])
  836. >>> b = np.array([1, 1+1j, 1-2j, 0])
  837. >>> x = solveh_banded(hb, b)
  838. >>> x
  839. array([ 0.07318536-0.02939412j, 0.11877624+0.17696461j,
  840. 0.10077984-0.23035393j, -0.00479904-0.09358128j])
  841. """
  842. a1 = _asarray_validated(ab, check_finite=check_finite)
  843. b1 = _asarray_validated(b, check_finite=check_finite)
  844. # Validate shapes.
  845. if a1.shape[-1] != b1.shape[0]:
  846. raise ValueError("shapes of ab and b are not compatible.")
  847. # accommodate empty arrays
  848. if b1.size == 0:
  849. dt = solve(np.eye(1, dtype=a1.dtype), np.ones(1, dtype=b1.dtype)).dtype
  850. return np.empty_like(b1, dtype=dt)
  851. overwrite_b = overwrite_b or _datacopied(b1, b)
  852. overwrite_ab = overwrite_ab or _datacopied(a1, ab)
  853. if a1.shape[0] == 2:
  854. ptsv, = get_lapack_funcs(('ptsv',), (a1, b1))
  855. if lower:
  856. d = a1[0, :].real
  857. e = a1[1, :-1]
  858. else:
  859. d = a1[1, :].real
  860. e = a1[0, 1:].conj()
  861. d, du, x, info = ptsv(d, e, b1, overwrite_ab, overwrite_ab,
  862. overwrite_b)
  863. else:
  864. pbsv, = get_lapack_funcs(('pbsv',), (a1, b1))
  865. c, x, info = pbsv(a1, b1, lower=lower, overwrite_ab=overwrite_ab,
  866. overwrite_b=overwrite_b)
  867. if info > 0:
  868. raise LinAlgError(f"{info}th leading minor not positive definite")
  869. if info < 0:
  870. raise ValueError(f'illegal value in {-info}th argument of internal pbsv')
  871. return x
  872. def solve_toeplitz(c_or_cr, b, check_finite=True):
  873. r"""Solve the equation ``T @ x = b`` for ``x``, where ``T`` is a Toeplitz
  874. matrix defined by `c_or_cr`.
  875. The Toeplitz matrix has constant diagonals, with ``c`` as its first column
  876. and ``r`` as its first row. If ``r`` is not given, ``r == conjugate(c)`` is
  877. assumed.
  878. The documentation is written assuming array arguments are of specified
  879. "core" shapes. However, array argument(s) of this function may have additional
  880. "batch" dimensions prepended to the core shape. In this case, the array is treated
  881. as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.
  882. Parameters
  883. ----------
  884. c_or_cr : array_like or tuple of (array_like, array_like)
  885. The vector ``c``, or a tuple of arrays (``c``, ``r``). If not
  886. supplied, ``r = conjugate(c)`` is assumed; in this case, if c[0] is
  887. real, the Toeplitz matrix is Hermitian. r[0] is ignored; the first row
  888. of the Toeplitz matrix is ``[c[0], r[1:]]``.
  889. b : (M,) or (M, K) array_like
  890. Right-hand side in ``T x = b``.
  891. check_finite : bool, optional
  892. Whether to check that the input matrices contain only finite numbers.
  893. Disabling may give a performance gain, but may result in problems
  894. (result entirely NaNs) if the inputs do contain infinities or NaNs.
  895. Returns
  896. -------
  897. x : (M,) or (M, K) ndarray
  898. The solution to the system ``T @ x = b``. Shape of return matches shape
  899. of `b`.
  900. See Also
  901. --------
  902. toeplitz : Toeplitz matrix
  903. Notes
  904. -----
  905. The solution is computed using Levinson-Durbin recursion, which is faster
  906. than generic least-squares methods, but can be less numerically stable.
  907. Examples
  908. --------
  909. Solve the Toeplitz system ``T @ x = b``, where::
  910. [ 1 -1 -2 -3] [1]
  911. T = [ 3 1 -1 -2] b = [2]
  912. [ 6 3 1 -1] [2]
  913. [10 6 3 1] [5]
  914. To specify the Toeplitz matrix, only the first column and the first
  915. row are needed.
  916. >>> import numpy as np
  917. >>> c = np.array([1, 3, 6, 10]) # First column of T
  918. >>> r = np.array([1, -1, -2, -3]) # First row of T
  919. >>> b = np.array([1, 2, 2, 5])
  920. >>> from scipy.linalg import solve_toeplitz, toeplitz
  921. >>> x = solve_toeplitz((c, r), b)
  922. >>> x
  923. array([ 1.66666667, -1. , -2.66666667, 2.33333333])
  924. Check the result by creating the full Toeplitz matrix and
  925. multiplying it by ``x``. We should get `b`.
  926. >>> T = toeplitz(c, r)
  927. >>> T.dot(x)
  928. array([ 1., 2., 2., 5.])
  929. """
  930. # If numerical stability of this algorithm is a problem, a future
  931. # developer might consider implementing other O(N^2) Toeplitz solvers,
  932. # such as GKO (https://www.jstor.org/stable/2153371) or Bareiss.
  933. c, r = c_or_cr if isinstance(c_or_cr, tuple) else (c_or_cr, np.conjugate(c_or_cr))
  934. return _solve_toeplitz(c, r, b, check_finite)
  935. @_apply_over_batch(('c', 1), ('r', 1), ('b', '1|2'))
  936. def _solve_toeplitz(c, r, b, check_finite):
  937. r, c, b, dtype, b_shape = _validate_args_for_toeplitz_ops(
  938. (c, r), b, check_finite, keep_b_shape=True)
  939. # accommodate empty arrays
  940. if b.size == 0:
  941. return np.empty_like(b)
  942. # Form a 1-D array of values to be used in the matrix, containing a
  943. # reversed copy of r[1:], followed by c.
  944. vals = np.concatenate((r[-1:0:-1], c))
  945. if b is None:
  946. raise ValueError('illegal value, `b` is a required argument')
  947. if b.ndim == 1:
  948. x, _ = levinson(vals, np.ascontiguousarray(b))
  949. else:
  950. x = np.column_stack([levinson(vals, np.ascontiguousarray(b[:, i]))[0]
  951. for i in range(b.shape[1])])
  952. x = x.reshape(*b_shape)
  953. return x
  954. def _get_axis_len(aname, a, axis):
  955. ax = axis
  956. if ax < 0:
  957. ax += a.ndim
  958. if 0 <= ax < a.ndim:
  959. return a.shape[ax]
  960. raise ValueError(f"'{aname}axis' entry is out of bounds")
  961. def solve_circulant(c, b, singular='raise', tol=None,
  962. caxis=-1, baxis=0, outaxis=0):
  963. """Solve the equation ``C @ x = b`` for ``x``, where ``C`` is a
  964. circulant matrix defined by `c`.
  965. `C` is the circulant matrix associated with the vector `c`.
  966. The system is solved by doing division in Fourier space. The
  967. calculation is::
  968. x = ifft(fft(b) / fft(c))
  969. where `fft` and `ifft` are the fast Fourier transform and its inverse,
  970. respectively. For a large vector `c`, this is *much* faster than
  971. solving the system with the full circulant matrix.
  972. Parameters
  973. ----------
  974. c : array_like
  975. The coefficients of the circulant matrix.
  976. b : array_like
  977. Right-hand side matrix in ``a x = b``.
  978. singular : str, optional
  979. This argument controls how a near singular circulant matrix is
  980. handled. If `singular` is "raise" and the circulant matrix is
  981. near singular, a `LinAlgError` is raised. If `singular` is
  982. "lstsq", the least squares solution is returned. Default is "raise".
  983. tol : float, optional
  984. If any eigenvalue of the circulant matrix has an absolute value
  985. that is less than or equal to `tol`, the matrix is considered to be
  986. near singular. If not given, `tol` is set to::
  987. tol = abs_eigs.max() * abs_eigs.size * np.finfo(np.float64).eps
  988. where `abs_eigs` is the array of absolute values of the eigenvalues
  989. of the circulant matrix.
  990. caxis : int
  991. When `c` has dimension greater than 1, it is viewed as a collection
  992. of circulant vectors. In this case, `caxis` is the axis of `c` that
  993. holds the vectors of circulant coefficients.
  994. baxis : int
  995. When `b` has dimension greater than 1, it is viewed as a collection
  996. of vectors. In this case, `baxis` is the axis of `b` that holds the
  997. right-hand side vectors.
  998. outaxis : int
  999. When `c` or `b` are multidimensional, the value returned by
  1000. `solve_circulant` is multidimensional. In this case, `outaxis` is
  1001. the axis of the result that holds the solution vectors.
  1002. Returns
  1003. -------
  1004. x : ndarray
  1005. Solution to the system ``C x = b``.
  1006. Raises
  1007. ------
  1008. LinAlgError
  1009. If the circulant matrix associated with `c` is near singular.
  1010. See Also
  1011. --------
  1012. circulant : circulant matrix
  1013. Notes
  1014. -----
  1015. For a 1-D vector `c` with length `m`, and an array `b`
  1016. with shape ``(m, ...)``,
  1017. solve_circulant(c, b)
  1018. returns the same result as
  1019. solve(circulant(c), b)
  1020. where `solve` and `circulant` are from `scipy.linalg`.
  1021. .. versionadded:: 0.16.0
  1022. Examples
  1023. --------
  1024. >>> import numpy as np
  1025. >>> from scipy.linalg import solve_circulant, solve, circulant, lstsq
  1026. >>> c = np.array([2, 2, 4])
  1027. >>> b = np.array([1, 2, 3])
  1028. >>> solve_circulant(c, b)
  1029. array([ 0.75, -0.25, 0.25])
  1030. Compare that result to solving the system with `scipy.linalg.solve`:
  1031. >>> solve(circulant(c), b)
  1032. array([ 0.75, -0.25, 0.25])
  1033. A singular example:
  1034. >>> c = np.array([1, 1, 0, 0])
  1035. >>> b = np.array([1, 2, 3, 4])
  1036. Calling ``solve_circulant(c, b)`` will raise a `LinAlgError`. For the
  1037. least square solution, use the option ``singular='lstsq'``:
  1038. >>> solve_circulant(c, b, singular='lstsq')
  1039. array([ 0.25, 1.25, 2.25, 1.25])
  1040. Compare to `scipy.linalg.lstsq`:
  1041. >>> x, resid, rnk, s = lstsq(circulant(c), b)
  1042. >>> x
  1043. array([ 0.25, 1.25, 2.25, 1.25])
  1044. A broadcasting example:
  1045. Suppose we have the vectors of two circulant matrices stored in an array
  1046. with shape (2, 5), and three `b` vectors stored in an array with shape
  1047. (3, 5). For example,
  1048. >>> c = np.array([[1.5, 2, 3, 0, 0], [1, 1, 4, 3, 2]])
  1049. >>> b = np.arange(15).reshape(-1, 5)
  1050. We want to solve all combinations of circulant matrices and `b` vectors,
  1051. with the result stored in an array with shape (2, 3, 5). When we
  1052. disregard the axes of `c` and `b` that hold the vectors of coefficients,
  1053. the shapes of the collections are (2,) and (3,), respectively, which are
  1054. not compatible for broadcasting. To have a broadcast result with shape
  1055. (2, 3), we add a trivial dimension to `c`: ``c[:, np.newaxis, :]`` has
  1056. shape (2, 1, 5). The last dimension holds the coefficients of the
  1057. circulant matrices, so when we call `solve_circulant`, we can use the
  1058. default ``caxis=-1``. The coefficients of the `b` vectors are in the last
  1059. dimension of the array `b`, so we use ``baxis=-1``. If we use the
  1060. default `outaxis`, the result will have shape (5, 2, 3), so we'll use
  1061. ``outaxis=-1`` to put the solution vectors in the last dimension.
  1062. >>> x = solve_circulant(c[:, np.newaxis, :], b, baxis=-1, outaxis=-1)
  1063. >>> x.shape
  1064. (2, 3, 5)
  1065. >>> np.set_printoptions(precision=3) # For compact output of numbers.
  1066. >>> x
  1067. array([[[-0.118, 0.22 , 1.277, -0.142, 0.302],
  1068. [ 0.651, 0.989, 2.046, 0.627, 1.072],
  1069. [ 1.42 , 1.758, 2.816, 1.396, 1.841]],
  1070. [[ 0.401, 0.304, 0.694, -0.867, 0.377],
  1071. [ 0.856, 0.758, 1.149, -0.412, 0.831],
  1072. [ 1.31 , 1.213, 1.603, 0.042, 1.286]]])
  1073. Check by solving one pair of `c` and `b` vectors (cf. ``x[1, 1, :]``):
  1074. >>> solve_circulant(c[1], b[1, :])
  1075. array([ 0.856, 0.758, 1.149, -0.412, 0.831])
  1076. """
  1077. c = np.atleast_1d(c)
  1078. nc = _get_axis_len("c", c, caxis)
  1079. b = np.atleast_1d(b)
  1080. nb = _get_axis_len("b", b, baxis)
  1081. if nc != nb:
  1082. raise ValueError(f'Shapes of c {c.shape} and b {b.shape} are incompatible')
  1083. # accommodate empty arrays
  1084. if b.size == 0:
  1085. dt = solve_circulant(np.arange(3, dtype=c.dtype),
  1086. np.ones(3, dtype=b.dtype)).dtype
  1087. return np.empty_like(b, dtype=dt)
  1088. fc = np.fft.fft(np.moveaxis(c, caxis, -1), axis=-1)
  1089. abs_fc = np.abs(fc)
  1090. if tol is None:
  1091. # This is the same tolerance as used in np.linalg.matrix_rank.
  1092. tol = abs_fc.max(axis=-1) * nc * np.finfo(np.float64).eps
  1093. if tol.shape != ():
  1094. tol = tol.reshape(tol.shape + (1,))
  1095. else:
  1096. tol = np.atleast_1d(tol)
  1097. near_zeros = abs_fc <= tol
  1098. is_near_singular = np.any(near_zeros)
  1099. if is_near_singular:
  1100. if singular == 'raise':
  1101. raise LinAlgError("near singular circulant matrix.")
  1102. else:
  1103. # Replace the small values with 1 to avoid errors in the
  1104. # division fb/fc below.
  1105. fc[near_zeros] = 1
  1106. fb = np.fft.fft(np.moveaxis(b, baxis, -1), axis=-1)
  1107. q = fb / fc
  1108. if is_near_singular:
  1109. # `near_zeros` is a boolean array, same shape as `c`, that is
  1110. # True where `fc` is (near) zero. `q` is the broadcasted result
  1111. # of fb / fc, so to set the values of `q` to 0 where `fc` is near
  1112. # zero, we use a mask that is the broadcast result of an array
  1113. # of True values shaped like `b` with `near_zeros`.
  1114. mask = np.ones_like(b, dtype=bool) & near_zeros
  1115. q[mask] = 0
  1116. x = np.fft.ifft(q, axis=-1)
  1117. if not (np.iscomplexobj(c) or np.iscomplexobj(b)):
  1118. x = x.real
  1119. if outaxis != -1:
  1120. x = np.moveaxis(x, -1, outaxis)
  1121. return x
  1122. # matrix inversion
  1123. def inv(a, overwrite_a=False, check_finite=True, *, assume_a=None, lower=False):
  1124. r"""
  1125. Compute the inverse of a matrix.
  1126. If the data matrix is known to be a particular type then supplying the
  1127. corresponding string to ``assume_a`` key chooses the dedicated solver.
  1128. The available options are
  1129. ============================= ================================
  1130. general 'general' (or 'gen')
  1131. diagonal 'diagonal'
  1132. upper triangular 'upper triangular'
  1133. lower triangular 'lower triangular'
  1134. symmetric positive definite 'pos'
  1135. symmetric 'sym'
  1136. Hermitian 'her'
  1137. ============================= ================================
  1138. For the 'pos' option, only the triangle of the input matrix specified in
  1139. the `lower` argument is used, and the other triangle is not referenced.
  1140. Likewise, an explicit `assume_a='diagonal'` means that off-diagonal elements
  1141. are not referenced.
  1142. Array argument(s) of this function may have additional
  1143. "batch" dimensions prepended to the core shape. In this case, the array is treated
  1144. as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.
  1145. Parameters
  1146. ----------
  1147. a : array_like, shape (..., M, M)
  1148. Square matrix (or a batch of matrices) to be inverted.
  1149. overwrite_a : bool, optional
  1150. Discard data in `a` (may improve performance). Default is False.
  1151. check_finite : bool, optional
  1152. Whether to check that the input matrix contains only finite numbers.
  1153. Disabling may give a performance gain, but may result in problems
  1154. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  1155. assume_a : str, optional
  1156. Valid entries are described above.
  1157. If omitted or ``None``, checks are performed to identify structure so the
  1158. appropriate solver can be called.
  1159. lower : bool, optional
  1160. Ignored unless `assume_a` is one of 'sym', 'her', or 'pos'. If True, the
  1161. calculation uses only the data in the lower triangle of `a`; entries above the
  1162. diagonal are ignored. If False (default), the calculation uses only the data in
  1163. the upper triangle of `a`; entries below the diagonal are ignored.
  1164. Returns
  1165. -------
  1166. ainv : ndarray
  1167. Inverse of the matrix `a`.
  1168. Raises
  1169. ------
  1170. LinAlgError
  1171. If `a` is singular.
  1172. ValueError
  1173. If `a` is not square, or not 2D.
  1174. Examples
  1175. --------
  1176. >>> import numpy as np
  1177. >>> from scipy import linalg
  1178. >>> a = np.array([[1., 2.], [3., 4.]])
  1179. >>> linalg.inv(a)
  1180. array([[-2. , 1. ],
  1181. [ 1.5, -0.5]])
  1182. >>> np.dot(a, linalg.inv(a))
  1183. array([[ 1., 0.],
  1184. [ 0., 1.]])
  1185. Notes
  1186. -----
  1187. The input array ``a`` may represent a single matrix or a collection (a.k.a.
  1188. a "batch") of square matrices. For example, if ``a.shape == (4, 3, 2, 2)``, it is
  1189. interpreted as a ``(4, 3)``-shaped batch of :math:`2\times 2` matrices.
  1190. This routine checks the condition number of the `a` matrix and emits a
  1191. `LinAlgWarning` for ill-conditioned inputs.
  1192. """
  1193. a1 = _asarray_validated(a, check_finite=check_finite)
  1194. if a1.ndim < 2:
  1195. raise ValueError(f"Expected at least ndim=2, got {a1.ndim=}")
  1196. if a1.shape[-1] != a1.shape[-2]:
  1197. raise ValueError(f"Expected square matrix, got {a1.shape=}")
  1198. # accommodate empty matrices
  1199. if a1.size == 0:
  1200. dt = inv(np.eye(2, dtype=a1.dtype)).dtype
  1201. return np.empty_like(a1, dtype=dt)
  1202. # Also check if dtype is LAPACK compatible
  1203. a1, overwrite_a = _normalize_lapack_dtype(a1, overwrite_a)
  1204. if not (a1.flags['ALIGNED'] or a1.dtype.byteorder == '='):
  1205. overwrite_a = True
  1206. a1 = a1.copy()
  1207. # XXX can relax a1.ndim == 2?
  1208. overwrite_a = overwrite_a and (a1.ndim == 2) and (a1.flags["F_CONTIGUOUS"])
  1209. # keep the numbers in sync with C at `linalg/src/_common_array_utils.hh`
  1210. structure = {
  1211. None: -1,
  1212. 'general': 0, 'gen': 0,
  1213. 'diagonal': 11,
  1214. 'upper triangular': 21,
  1215. 'lower triangular': 22,
  1216. 'pos' : 101,
  1217. 'sym' : 201,
  1218. 'her' : 211,
  1219. }[assume_a]
  1220. # a1 is well behaved, invert it.
  1221. inv_a, err_lst = _batched_linalg._inv(a1, structure, overwrite_a, lower)
  1222. if err_lst:
  1223. _format_emit_errors_warnings(err_lst)
  1224. return inv_a
  1225. # Determinant
  1226. def det(a, overwrite_a=False, check_finite=True):
  1227. """
  1228. Compute the determinant of a matrix
  1229. The determinant is a scalar that is a function of the associated square
  1230. matrix coefficients. The determinant value is zero for singular matrices.
  1231. Array argument(s) of this function may have additional
  1232. "batch" dimensions prepended to the core shape. In this case, the array is treated
  1233. as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.
  1234. Parameters
  1235. ----------
  1236. a : (..., M, M) array_like
  1237. Input array to compute determinants for.
  1238. overwrite_a : bool, optional
  1239. Allow overwriting data in a (may enhance performance).
  1240. check_finite : bool, optional
  1241. Whether to check that the input matrix contains only finite numbers.
  1242. Disabling may give a performance gain, but may result in problems
  1243. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  1244. Returns
  1245. -------
  1246. det : (...) float or complex
  1247. Determinant of `a`. For stacked arrays, a scalar is returned for each
  1248. (m, m) slice in the last two dimensions of the input. For example, an
  1249. input of shape (p, q, m, m) will produce a result of shape (p, q). If
  1250. all dimensions are 1 a scalar is returned regardless of ndim.
  1251. Notes
  1252. -----
  1253. The determinant is computed by performing an LU factorization of the
  1254. input with LAPACK routine 'getrf', and then calculating the product of
  1255. diagonal entries of the U factor.
  1256. Even if the input array is single precision (float32 or complex64), the
  1257. result will be returned in double precision (float64 or complex128) to
  1258. prevent overflows.
  1259. Examples
  1260. --------
  1261. >>> import numpy as np
  1262. >>> from scipy import linalg
  1263. >>> a = np.array([[1,2,3], [4,5,6], [7,8,9]]) # A singular matrix
  1264. >>> linalg.det(a)
  1265. 0.0
  1266. >>> b = np.array([[0,2,3], [4,5,6], [7,8,9]])
  1267. >>> linalg.det(b)
  1268. 3.0
  1269. >>> # An array with the shape (3, 2, 2, 2)
  1270. >>> c = np.array([[[[1., 2.], [3., 4.]],
  1271. ... [[5., 6.], [7., 8.]]],
  1272. ... [[[9., 10.], [11., 12.]],
  1273. ... [[13., 14.], [15., 16.]]],
  1274. ... [[[17., 18.], [19., 20.]],
  1275. ... [[21., 22.], [23., 24.]]]])
  1276. >>> linalg.det(c) # The resulting shape is (3, 2)
  1277. array([[-2., -2.],
  1278. [-2., -2.],
  1279. [-2., -2.]])
  1280. >>> linalg.det(c[0, 0]) # Confirm the (0, 0) slice, [[1, 2], [3, 4]]
  1281. -2.0
  1282. """
  1283. # The goal is to end up with a writable contiguous array to pass to Cython
  1284. # First we check and make arrays.
  1285. a1 = np.asarray_chkfinite(a) if check_finite else np.asarray(a)
  1286. if a1.ndim < 2:
  1287. raise ValueError('The input array must be at least two-dimensional.')
  1288. if a1.shape[-1] != a1.shape[-2]:
  1289. raise ValueError('Last 2 dimensions of the array must be square'
  1290. f' but received shape {a1.shape}.')
  1291. # Also check if dtype is LAPACK compatible
  1292. a1, overwrite_a = _normalize_lapack_dtype1(a1, overwrite_a)
  1293. # Empty array has determinant 1 because math.
  1294. if min(*a1.shape) == 0:
  1295. dtyp = np.float64 if a1.dtype.char not in 'FD' else np.complex128
  1296. if a1.ndim == 2:
  1297. return dtyp(1.0)
  1298. else:
  1299. return np.ones(shape=a1.shape[:-2], dtype=dtyp)
  1300. # Scalar case
  1301. if a1.shape[-2:] == (1, 1):
  1302. a1 = a1[..., 0, 0]
  1303. if a1.ndim == 0:
  1304. a1 = a1[()]
  1305. # Convert float32 to float64, and complex64 to complex128.
  1306. if a1.dtype.char in 'dD':
  1307. return a1
  1308. return a1.astype('d') if a1.dtype.char == 'f' else a1.astype('D')
  1309. # Then check overwrite permission
  1310. if not _datacopied(a1, a): # "a" still alive through "a1"
  1311. if not overwrite_a:
  1312. # Data belongs to "a" so make a copy
  1313. a1 = a1.copy(order='C')
  1314. # else: Do nothing we'll use "a" if possible
  1315. # else: a1 has its own data thus free to scratch
  1316. # Then layout checks, might happen that overwrite is allowed but original
  1317. # array was read-only or non-C-contiguous.
  1318. if not (a1.flags['C_CONTIGUOUS'] and a1.flags['WRITEABLE']):
  1319. a1 = a1.copy(order='C')
  1320. if a1.ndim == 2:
  1321. det = find_det_from_lu(a1)
  1322. # Convert float, complex to NumPy scalars
  1323. return (np.float64(det) if np.isrealobj(det) else np.complex128(det))
  1324. # loop over the stacked array, and avoid overflows for single precision
  1325. # Cf. np.linalg.det(np.diag([1e+38, 1e+38]).astype(np.float32))
  1326. dtype_char = a1.dtype.char
  1327. if dtype_char in 'fF':
  1328. dtype_char = 'd' if dtype_char.islower() else 'D'
  1329. det = np.empty(a1.shape[:-2], dtype=dtype_char)
  1330. for ind in product(*[range(x) for x in a1.shape[:-2]]):
  1331. det[ind] = find_det_from_lu(a1[ind])
  1332. return det
  1333. # Linear Least Squares
  1334. @_apply_over_batch(('a', 2), ('b', '1|2'))
  1335. def lstsq(a, b, cond=None, overwrite_a=False, overwrite_b=False,
  1336. check_finite=True, lapack_driver=None):
  1337. """
  1338. Compute least-squares solution to the equation ``a @ x = b``.
  1339. Compute a vector x such that the 2-norm ``|b - A x|`` is minimized.
  1340. Parameters
  1341. ----------
  1342. a : (M, N) array_like
  1343. Left-hand side array
  1344. b : (M,) or (M, K) array_like
  1345. Right hand side array
  1346. cond : float, optional
  1347. Cutoff for 'small' singular values; used to determine effective
  1348. rank of a. Singular values smaller than
  1349. ``cond * largest_singular_value`` are considered zero.
  1350. overwrite_a : bool, optional
  1351. Discard data in `a` (may enhance performance). Default is False.
  1352. overwrite_b : bool, optional
  1353. Discard data in `b` (may enhance performance). Default is False.
  1354. check_finite : bool, optional
  1355. Whether to check that the input matrices contain only finite numbers.
  1356. Disabling may give a performance gain, but may result in problems
  1357. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  1358. lapack_driver : str, optional
  1359. Which LAPACK driver is used to solve the least-squares problem.
  1360. Options are ``'gelsd'``, ``'gelsy'``, ``'gelss'``. Default
  1361. (``'gelsd'``) is a good choice. However, ``'gelsy'`` can be slightly
  1362. faster on many problems. ``'gelss'`` was used historically. It is
  1363. generally slow but uses less memory.
  1364. .. versionadded:: 0.17.0
  1365. Returns
  1366. -------
  1367. x : (N,) or (N, K) ndarray
  1368. Least-squares solution.
  1369. residues : (K,) ndarray or float
  1370. Square of the 2-norm for each column in ``b - a x``, if ``M > N`` and
  1371. ``rank(A) == n`` (returns a scalar if ``b`` is 1-D). Otherwise a
  1372. (0,)-shaped array is returned.
  1373. rank : int
  1374. Effective rank of `a`.
  1375. s : (min(M, N),) ndarray or None
  1376. Singular values of `a`. The condition number of ``a`` is
  1377. ``s[0] / s[-1]``.
  1378. Raises
  1379. ------
  1380. LinAlgError
  1381. If computation does not converge.
  1382. ValueError
  1383. When parameters are not compatible.
  1384. See Also
  1385. --------
  1386. scipy.optimize.nnls : linear least squares with non-negativity constraint
  1387. Notes
  1388. -----
  1389. When ``'gelsy'`` is used as a driver, `residues` is set to a (0,)-shaped
  1390. array and `s` is always ``None``.
  1391. Examples
  1392. --------
  1393. >>> import numpy as np
  1394. >>> from scipy.linalg import lstsq
  1395. >>> import matplotlib.pyplot as plt
  1396. Suppose we have the following data:
  1397. >>> x = np.array([1, 2.5, 3.5, 4, 5, 7, 8.5])
  1398. >>> y = np.array([0.3, 1.1, 1.5, 2.0, 3.2, 6.6, 8.6])
  1399. We want to fit a quadratic polynomial of the form ``y = a + b*x**2``
  1400. to this data. We first form the "design matrix" M, with a constant
  1401. column of 1s and a column containing ``x**2``:
  1402. >>> M = x[:, np.newaxis]**[0, 2]
  1403. >>> M
  1404. array([[ 1. , 1. ],
  1405. [ 1. , 6.25],
  1406. [ 1. , 12.25],
  1407. [ 1. , 16. ],
  1408. [ 1. , 25. ],
  1409. [ 1. , 49. ],
  1410. [ 1. , 72.25]])
  1411. We want to find the least-squares solution to ``M.dot(p) = y``,
  1412. where ``p`` is a vector with length 2 that holds the parameters
  1413. ``a`` and ``b``.
  1414. >>> p, res, rnk, s = lstsq(M, y)
  1415. >>> p
  1416. array([ 0.20925829, 0.12013861])
  1417. Plot the data and the fitted curve.
  1418. >>> plt.plot(x, y, 'o', label='data')
  1419. >>> xx = np.linspace(0, 9, 101)
  1420. >>> yy = p[0] + p[1]*xx**2
  1421. >>> plt.plot(xx, yy, label='least squares fit, $y = a + bx^2$')
  1422. >>> plt.xlabel('x')
  1423. >>> plt.ylabel('y')
  1424. >>> plt.legend(framealpha=1, shadow=True)
  1425. >>> plt.grid(alpha=0.25)
  1426. >>> plt.show()
  1427. """
  1428. a1 = _asarray_validated(a, check_finite=check_finite)
  1429. b1 = _asarray_validated(b, check_finite=check_finite)
  1430. if len(a1.shape) != 2:
  1431. raise ValueError('Input array a should be 2D')
  1432. m, n = a1.shape
  1433. if len(b1.shape) == 2:
  1434. nrhs = b1.shape[1]
  1435. else:
  1436. nrhs = 1
  1437. if m != b1.shape[0]:
  1438. raise ValueError('Shape mismatch: a and b should have the same number'
  1439. f' of rows ({m} != {b1.shape[0]}).')
  1440. if m == 0 or n == 0: # Zero-sized problem, confuses LAPACK
  1441. x = np.zeros((n,) + b1.shape[1:], dtype=np.common_type(a1, b1))
  1442. if n == 0:
  1443. residues = np.linalg.norm(b1, axis=0)**2
  1444. else:
  1445. residues = np.empty((0,))
  1446. return x, residues, 0, np.empty((0,))
  1447. driver = lapack_driver
  1448. if driver is None:
  1449. driver = lstsq.default_lapack_driver
  1450. if driver not in ('gelsd', 'gelsy', 'gelss'):
  1451. raise ValueError(f'LAPACK driver "{driver}" is not found')
  1452. lapack_func, lapack_lwork = get_lapack_funcs((driver,
  1453. f'{driver}_lwork'),
  1454. (a1, b1))
  1455. real_data = True if (lapack_func.dtype.kind == 'f') else False
  1456. if m < n:
  1457. # need to extend b matrix as it will be filled with
  1458. # a larger solution matrix
  1459. if len(b1.shape) == 2:
  1460. b2 = np.zeros((n, nrhs), dtype=lapack_func.dtype)
  1461. b2[:m, :] = b1
  1462. else:
  1463. b2 = np.zeros(n, dtype=lapack_func.dtype)
  1464. b2[:m] = b1
  1465. b1 = b2
  1466. overwrite_a = overwrite_a or _datacopied(a1, a)
  1467. overwrite_b = overwrite_b or _datacopied(b1, b)
  1468. if cond is None:
  1469. cond = np.finfo(lapack_func.dtype).eps
  1470. if driver in ('gelss', 'gelsd'):
  1471. if driver == 'gelss':
  1472. lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
  1473. v, x, s, rank, work, info = lapack_func(a1, b1, cond, lwork,
  1474. overwrite_a=overwrite_a,
  1475. overwrite_b=overwrite_b)
  1476. elif driver == 'gelsd':
  1477. if real_data:
  1478. lwork, iwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
  1479. x, s, rank, info = lapack_func(a1, b1, lwork,
  1480. iwork, cond, False, False)
  1481. else: # complex data
  1482. lwork, rwork, iwork = _compute_lwork(lapack_lwork, m, n,
  1483. nrhs, cond)
  1484. x, s, rank, info = lapack_func(a1, b1, lwork, rwork, iwork,
  1485. cond, False, False)
  1486. if info > 0:
  1487. raise LinAlgError("SVD did not converge in Linear Least Squares")
  1488. if info < 0:
  1489. raise ValueError(
  1490. f'illegal value in {-info}-th argument of internal {lapack_driver}'
  1491. )
  1492. resids = np.asarray([], dtype=x.dtype)
  1493. if m > n:
  1494. x1 = x[:n]
  1495. if rank == n:
  1496. resids = np.sum(np.abs(x[n:])**2, axis=0)
  1497. x = x1
  1498. return x, resids, rank, s
  1499. elif driver == 'gelsy':
  1500. lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
  1501. jptv = np.zeros((a1.shape[1], 1), dtype=np.int32)
  1502. v, x, j, rank, info = lapack_func(a1, b1, jptv, cond,
  1503. lwork, False, False)
  1504. if info < 0:
  1505. raise ValueError(f'illegal value in {-info}-th argument of internal gelsy')
  1506. if m > n:
  1507. x1 = x[:n]
  1508. x = x1
  1509. return x, np.array([], x.dtype), rank, None
  1510. lstsq.default_lapack_driver = 'gelsd'
  1511. @_apply_over_batch(('a', 2))
  1512. def pinv(a, *, atol=None, rtol=None, return_rank=False, check_finite=True):
  1513. """
  1514. Compute the (Moore-Penrose) pseudo-inverse of a matrix.
  1515. Calculate a generalized inverse of a matrix using its
  1516. singular-value decomposition ``U @ S @ V`` in the economy mode and picking
  1517. up only the columns/rows that are associated with significant singular
  1518. values.
  1519. If ``s`` is the maximum singular value of ``a``, then the
  1520. significance cut-off value is determined by ``atol + rtol * s``. Any
  1521. singular value below this value is assumed insignificant.
  1522. Parameters
  1523. ----------
  1524. a : (M, N) array_like
  1525. Matrix to be pseudo-inverted.
  1526. atol : float, optional
  1527. Absolute threshold term, default value is 0.
  1528. .. versionadded:: 1.7.0
  1529. rtol : float, optional
  1530. Relative threshold term, default value is ``max(M, N) * eps`` where
  1531. ``eps`` is the machine precision value of the datatype of ``a``.
  1532. .. versionadded:: 1.7.0
  1533. return_rank : bool, optional
  1534. If True, return the effective rank of the matrix.
  1535. check_finite : bool, optional
  1536. Whether to check that the input matrix contains only finite numbers.
  1537. Disabling may give a performance gain, but may result in problems
  1538. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  1539. Returns
  1540. -------
  1541. B : (N, M) ndarray
  1542. The pseudo-inverse of matrix `a`.
  1543. rank : int
  1544. The effective rank of the matrix. Returned if `return_rank` is True.
  1545. Raises
  1546. ------
  1547. LinAlgError
  1548. If SVD computation does not converge.
  1549. See Also
  1550. --------
  1551. pinvh : Moore-Penrose pseudoinverse of a hermitian matrix.
  1552. Notes
  1553. -----
  1554. If ``A`` is invertible then the Moore-Penrose pseudoinverse is exactly
  1555. the inverse of ``A`` [1]_. If ``A`` is not invertible then the
  1556. Moore-Penrose pseudoinverse computes the ``x`` solution to ``Ax = b`` such
  1557. that ``||Ax - b||`` is minimized [1]_.
  1558. References
  1559. ----------
  1560. .. [1] Penrose, R. (1956). On best approximate solutions of linear matrix
  1561. equations. Mathematical Proceedings of the Cambridge Philosophical
  1562. Society, 52(1), 17-19. doi:10.1017/S0305004100030929
  1563. Examples
  1564. --------
  1565. Given an ``m x n`` matrix ``A`` and an ``n x m`` matrix ``B`` the four
  1566. Moore-Penrose conditions are:
  1567. 1. ``ABA = A`` (``B`` is a generalized inverse of ``A``),
  1568. 2. ``BAB = B`` (``A`` is a generalized inverse of ``B``),
  1569. 3. ``(AB)* = AB`` (``AB`` is hermitian),
  1570. 4. ``(BA)* = BA`` (``BA`` is hermitian) [1]_.
  1571. Here, ``A*`` denotes the conjugate transpose. The Moore-Penrose
  1572. pseudoinverse is a unique ``B`` that satisfies all four of these
  1573. conditions and exists for any ``A``. Note that, unlike the standard
  1574. matrix inverse, ``A`` does not have to be a square matrix or have
  1575. linearly independent columns/rows.
  1576. As an example, we can calculate the Moore-Penrose pseudoinverse of a
  1577. random non-square matrix and verify it satisfies the four conditions.
  1578. >>> import numpy as np
  1579. >>> from scipy import linalg
  1580. >>> rng = np.random.default_rng()
  1581. >>> A = rng.standard_normal((9, 6))
  1582. >>> B = linalg.pinv(A)
  1583. >>> np.allclose(A @ B @ A, A) # Condition 1
  1584. True
  1585. >>> np.allclose(B @ A @ B, B) # Condition 2
  1586. True
  1587. >>> np.allclose((A @ B).conj().T, A @ B) # Condition 3
  1588. True
  1589. >>> np.allclose((B @ A).conj().T, B @ A) # Condition 4
  1590. True
  1591. """
  1592. a = _asarray_validated(a, check_finite=check_finite)
  1593. u, s, vh = _decomp_svd.svd(a, full_matrices=False, check_finite=False)
  1594. t = u.dtype.char.lower()
  1595. maxS = np.max(s, initial=0.)
  1596. atol = 0. if atol is None else atol
  1597. rtol = max(a.shape) * np.finfo(t).eps if (rtol is None) else rtol
  1598. if (atol < 0.) or (rtol < 0.):
  1599. raise ValueError("atol and rtol values must be positive.")
  1600. val = atol + maxS * rtol
  1601. rank = np.sum(s > val)
  1602. u = u[:, :rank]
  1603. u /= s[:rank]
  1604. B = (u @ vh[:rank]).conj().T
  1605. if return_rank:
  1606. return B, rank
  1607. else:
  1608. return B
  1609. @_apply_over_batch(('a', 2))
  1610. def pinvh(a, atol=None, rtol=None, lower=True, return_rank=False,
  1611. check_finite=True):
  1612. """
  1613. Compute the (Moore-Penrose) pseudo-inverse of a Hermitian matrix.
  1614. Calculate a generalized inverse of a complex Hermitian/real symmetric
  1615. matrix using its eigenvalue decomposition and including all eigenvalues
  1616. with 'large' absolute value.
  1617. Parameters
  1618. ----------
  1619. a : (N, N) array_like
  1620. Real symmetric or complex hermetian matrix to be pseudo-inverted
  1621. atol : float, optional
  1622. Absolute threshold term, default value is 0.
  1623. .. versionadded:: 1.7.0
  1624. rtol : float, optional
  1625. Relative threshold term, default value is ``N * eps`` where
  1626. ``eps`` is the machine precision value of the datatype of ``a``.
  1627. .. versionadded:: 1.7.0
  1628. lower : bool, optional
  1629. Whether the pertinent array data is taken from the lower or upper
  1630. triangle of `a`. (Default: lower)
  1631. return_rank : bool, optional
  1632. If True, return the effective rank of the matrix.
  1633. check_finite : bool, optional
  1634. Whether to check that the input matrix contains only finite numbers.
  1635. Disabling may give a performance gain, but may result in problems
  1636. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  1637. Returns
  1638. -------
  1639. B : (N, N) ndarray
  1640. The pseudo-inverse of matrix `a`.
  1641. rank : int
  1642. The effective rank of the matrix. Returned if `return_rank` is True.
  1643. Raises
  1644. ------
  1645. LinAlgError
  1646. If eigenvalue algorithm does not converge.
  1647. See Also
  1648. --------
  1649. pinv : Moore-Penrose pseudoinverse of a matrix.
  1650. Examples
  1651. --------
  1652. For a more detailed example see `pinv`.
  1653. >>> import numpy as np
  1654. >>> from scipy.linalg import pinvh
  1655. >>> rng = np.random.default_rng()
  1656. >>> a = rng.standard_normal((9, 6))
  1657. >>> a = np.dot(a, a.T)
  1658. >>> B = pinvh(a)
  1659. >>> np.allclose(a, a @ B @ a)
  1660. True
  1661. >>> np.allclose(B, B @ a @ B)
  1662. True
  1663. """
  1664. a = _asarray_validated(a, check_finite=check_finite)
  1665. s, u = _decomp.eigh(a, lower=lower, check_finite=False, driver='ev')
  1666. t = u.dtype.char.lower()
  1667. maxS = np.max(np.abs(s), initial=0.)
  1668. atol = 0. if atol is None else atol
  1669. rtol = max(a.shape) * np.finfo(t).eps if (rtol is None) else rtol
  1670. if (atol < 0.) or (rtol < 0.):
  1671. raise ValueError("atol and rtol values must be positive.")
  1672. val = atol + maxS * rtol
  1673. above_cutoff = (abs(s) > val)
  1674. psigma_diag = 1.0 / s[above_cutoff]
  1675. u = u[:, above_cutoff]
  1676. B = (u * psigma_diag) @ u.conj().T
  1677. if return_rank:
  1678. return B, len(psigma_diag)
  1679. else:
  1680. return B
@_apply_over_batch(('A', 2))
def matrix_balance(A, permute=True, scale=True, separate=False,
                   overwrite_a=False):
    """
    Compute a diagonal similarity transformation for row/column balancing.

    The balancing tries to equalize the row and column 1-norms by applying
    a similarity transformation such that the magnitude variation of the
    matrix entries is reflected to the scaling matrices.

    Moreover, if enabled, the matrix is first permuted to isolate the upper
    triangular parts of the matrix and, again if scaling is also enabled,
    only the remaining subblocks are subjected to scaling.

    Parameters
    ----------
    A : (n, n) array_like
        Square data matrix for the balancing. Its entries are always checked
        for finiteness (there is no ``check_finite`` switch).
    permute : bool, optional
        The selector to define whether permutation of A is also performed
        prior to scaling.
    scale : bool, optional
        The selector to turn on and off the scaling. If False, the matrix
        will not be scaled.
    separate : bool, optional
        This switches from returning a full matrix of the transformation
        to a tuple of two separate 1-D permutation and scaling arrays.
    overwrite_a : bool, optional
        This is passed to xGEBAL directly. If True, the result may be written
        over the data of `A`, which can save memory. See the LAPACK manual
        for details. This is False by default.

    Returns
    -------
    B : (n, n) ndarray
        Balanced matrix
    T : (n, n) ndarray
        A possibly permuted diagonal matrix whose nonzero entries are
        integer powers of 2 to avoid numerical truncation errors.
    scale, perm : (n,) ndarray
        If ``separate`` keyword is set to True then instead of the array
        ``T`` above, the scaling and the permutation vectors are given
        separately as a tuple without allocating the full array ``T``.

    Raises
    ------
    ValueError
        If `A` is not square, or if xGEBAL reports an illegal argument
        (negative ``info``).

    Notes
    -----
    The balanced matrix satisfies the following equality

    .. math::

        B = T^{-1} A T

    The scaling coefficients are approximated to the nearest power of 2
    to avoid round-off errors.

    This algorithm is particularly useful for eigenvalue and matrix
    decompositions and in many cases it is already called by various
    LAPACK routines.

    The algorithm is based on the well-known technique of [1]_ and has
    been modified to account for special cases. See [2]_ for details
    which have been implemented since LAPACK v3.5.0. Before this version
    there are corner cases where balancing can actually worsen the
    conditioning. See [3]_ for such examples.

    The code is a wrapper around LAPACK's xGEBAL routine family for matrix
    balancing.

    .. versionadded:: 0.19.0

    References
    ----------
    .. [1] B.N. Parlett and C. Reinsch, "Balancing a Matrix for
       Calculation of Eigenvalues and Eigenvectors", Numerische Mathematik,
       Vol.13(4), 1969, :doi:`10.1007/BF02165404`
    .. [2] R. James, J. Langou, B.R. Lowery, "On matrix balancing and
       eigenvector computation", 2014, :arxiv:`1401.5766`
    .. [3] D.S. Watkins. A case where balancing is harmful.
       Electron. Trans. Numer. Anal, Vol.23, 2006.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy import linalg
    >>> x = np.array([[1,2,0], [9,1,0.01], [1,2,10*np.pi]])

    >>> y, permscale = linalg.matrix_balance(x)
    >>> np.abs(x).sum(axis=0) / np.abs(x).sum(axis=1)
    array([ 3.66666667,  0.4995005 ,  0.91312162])

    >>> np.abs(y).sum(axis=0) / np.abs(y).sum(axis=1)
    array([ 1.2       ,  1.27041742,  0.92658316])  # may vary

    >>> permscale  # only powers of 2 (0.5 == 2^(-1))
    array([[  0.5,   0. ,  0. ],  # may vary
           [  0. ,   1. ,  0. ],
           [  0. ,   0. ,  1. ]])

    """
    # Finiteness is checked unconditionally here (check_finite=True).
    A = np.atleast_2d(_asarray_validated(A, check_finite=True))

    if not np.equal(*A.shape):
        raise ValueError('The data matrix for balancing should be square.')

    # accommodate empty arrays
    if A.size == 0:
        # Balance a tiny identity of the same dtype purely to learn the
        # dtypes xGEBAL would produce; its values are discarded.
        b_n, t_n = matrix_balance(np.eye(2, dtype=A.dtype))
        B = np.empty_like(A, dtype=b_n.dtype)
        if separate:
            # NOTE(review): ``scaling`` takes A's dtype here, whereas the
            # non-empty path below always produces float scaling.
            scaling = np.ones_like(A, shape=len(A))
            perm = np.arange(len(A))
            return B, (scaling, perm)
        return B, np.empty_like(A, dtype=t_n.dtype)

    gebal = get_lapack_funcs(('gebal'), (A,))
    B, lo, hi, ps, info = gebal(A, scale=scale, permute=permute,
                                overwrite_a=overwrite_a)

    if info < 0:
        raise ValueError('xGEBAL exited with the internal error '
                         f'"illegal value in argument number {-info}.". See '
                         'LAPACK documentation for the xGEBAL error codes.')

    # Separate the permutations from the scalings and then convert to int
    # Only ps[lo:hi+1] are scale factors; the entries outside that range are
    # used below as permutation indices, so they get a scale of 1.
    # NOTE(review): lo/hi are used directly as 0-based Python slice bounds,
    # so the wrapper presumably already converts them; confirm.
    scaling = np.ones_like(ps, dtype=float)
    scaling[lo:hi+1] = ps[lo:hi+1]

    # gebal uses 1-indexing
    ps = ps.astype(int, copy=False) - 1
    n = A.shape[0]
    perm = np.arange(n)

    # LAPACK permutes with the ordering n --> hi, then 0--> lo
    # Replay the recorded swaps on ``perm``: the trailing entries from the
    # last index backwards first, then the leading entries front to back.
    if hi < n:
        for ind, x in enumerate(ps[hi+1:][::-1], 1):
            if n-ind == x:
                continue  # index swapped with itself: nothing to do
            perm[[x, n-ind]] = perm[[n-ind, x]]

    if lo > 0:
        for ind, x in enumerate(ps[:lo]):
            if ind == x:
                continue  # index swapped with itself: nothing to do
            perm[[x, ind]] = perm[[ind, x]]

    if separate:
        return B, (scaling, perm)

    # get the inverse permutation
    iperm = np.empty_like(perm)
    iperm[perm] = np.arange(n)

    # Dense T: the diagonal scaling matrix with rows reordered by iperm.
    return B, np.diag(scaling)[iperm, :]
  1805. def _validate_args_for_toeplitz_ops(c_or_cr, b, check_finite, keep_b_shape,
  1806. enforce_square=True):
  1807. """Validate arguments and format inputs for toeplitz functions
  1808. Parameters
  1809. ----------
  1810. c_or_cr : array_like or tuple of (array_like, array_like)
  1811. The vector ``c``, or a tuple of arrays (``c``, ``r``). Whatever the
  1812. actual shape of ``c``, it will be converted to a 1-D array. If not
  1813. supplied, ``r = conjugate(c)`` is assumed; in this case, if c[0] is
  1814. real, the Toeplitz matrix is Hermitian. r[0] is ignored; the first row
  1815. of the Toeplitz matrix is ``[c[0], r[1:]]``. Whatever the actual shape
  1816. of ``r``, it will be converted to a 1-D array.
  1817. b : (M,) or (M, K) array_like
  1818. Right-hand side in ``T x = b``.
  1819. check_finite : bool
  1820. Whether to check that the input matrices contain only finite numbers.
  1821. Disabling may give a performance gain, but may result in problems
  1822. (result entirely NaNs) if the inputs do contain infinities or NaNs.
  1823. keep_b_shape : bool
  1824. Whether to convert a (M,) dimensional b into a (M, 1) dimensional
  1825. matrix.
  1826. enforce_square : bool, optional
  1827. If True (default), this verifies that the Toeplitz matrix is square.
  1828. Returns
  1829. -------
  1830. r : array
  1831. 1d array corresponding to the first row of the Toeplitz matrix.
  1832. c: array
  1833. 1d array corresponding to the first column of the Toeplitz matrix.
  1834. b: array
  1835. (M,), (M, 1) or (M, K) dimensional array, post validation,
  1836. corresponding to ``b``.
  1837. dtype: numpy datatype
  1838. ``dtype`` stores the datatype of ``r``, ``c`` and ``b``. If any of
  1839. ``r``, ``c`` or ``b`` are complex, ``dtype`` is ``np.complex128``,
  1840. otherwise, it is ``np.float``.
  1841. b_shape: tuple
  1842. Shape of ``b`` after passing it through ``_asarray_validated``.
  1843. """
  1844. if isinstance(c_or_cr, tuple):
  1845. c, r = c_or_cr
  1846. c = _asarray_validated(c, check_finite=check_finite)
  1847. r = _asarray_validated(r, check_finite=check_finite)
  1848. else:
  1849. c = _asarray_validated(c_or_cr, check_finite=check_finite)
  1850. r = c.conjugate()
  1851. if b is None:
  1852. raise ValueError('`b` must be an array, not None.')
  1853. b = _asarray_validated(b, check_finite=check_finite)
  1854. b_shape = b.shape
  1855. is_not_square = r.shape[0] != c.shape[0]
  1856. if (enforce_square and is_not_square) or b.shape[0] != r.shape[0]:
  1857. raise ValueError('Incompatible dimensions.')
  1858. is_cmplx = np.iscomplexobj(r) or np.iscomplexobj(c) or np.iscomplexobj(b)
  1859. dtype = np.complex128 if is_cmplx else np.float64
  1860. r, c, b = (np.asarray(i, dtype=dtype) for i in (r, c, b))
  1861. if b.ndim == 1 and not keep_b_shape:
  1862. b = b.reshape(-1, 1)
  1863. elif b.ndim != 1:
  1864. b = b.reshape(b.shape[0], -1 if b.size > 0 else 0)
  1865. return r, c, b, dtype, b_shape
  1866. def matmul_toeplitz(c_or_cr, x, check_finite=False, workers=None):
  1867. r"""Efficient Toeplitz Matrix-Matrix Multiplication using FFT
  1868. This function returns the matrix multiplication between a Toeplitz
  1869. matrix and a dense matrix.
  1870. The Toeplitz matrix has constant diagonals, with c as its first column
  1871. and r as its first row. If r is not given, ``r == conjugate(c)`` is
  1872. assumed.
  1873. The documentation is written assuming array arguments are of specified
  1874. "core" shapes. However, array argument(s) of this function may have additional
  1875. "batch" dimensions prepended to the core shape. In this case, the array is treated
  1876. as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.
  1877. Parameters
  1878. ----------
  1879. c_or_cr : array_like or tuple of (array_like, array_like)
  1880. The vector ``c``, or a tuple of arrays (``c``, ``r``). If not
  1881. supplied, ``r = conjugate(c)`` is assumed; in this case, if c[0] is
  1882. real, the Toeplitz matrix is Hermitian. r[0] is ignored; the first row
  1883. of the Toeplitz matrix is ``[c[0], r[1:]]``.
  1884. x : (M,) or (M, K) array_like
  1885. Matrix with which to multiply.
  1886. check_finite : bool, optional
  1887. Whether to check that the input matrices contain only finite numbers.
  1888. Disabling may give a performance gain, but may result in problems
  1889. (result entirely NaNs) if the inputs do contain infinities or NaNs.
  1890. workers : int, optional
  1891. To pass to scipy.fft.fft and ifft. Maximum number of workers to use
  1892. for parallel computation. If negative, the value wraps around from
  1893. ``os.cpu_count()``. See scipy.fft.fft for more details.
  1894. Returns
  1895. -------
  1896. T @ x : (M,) or (M, K) ndarray
  1897. The result of the matrix multiplication ``T @ x``. Shape of return
  1898. matches shape of `x`.
  1899. See Also
  1900. --------
  1901. toeplitz : Toeplitz matrix
  1902. solve_toeplitz : Solve a Toeplitz system using Levinson Recursion
  1903. Notes
  1904. -----
  1905. The Toeplitz matrix is embedded in a circulant matrix and the FFT is used
  1906. to efficiently calculate the matrix-matrix product.
  1907. Because the computation is based on the FFT, integer inputs will
  1908. result in floating point outputs. This is unlike NumPy's `matmul`,
  1909. which preserves the data type of the input.
  1910. This is partly based on the implementation that can be found in [1]_,
  1911. licensed under the MIT license. More information about the method can be
  1912. found in reference [2]_. References [3]_ and [4]_ have more reference
  1913. implementations in Python.
  1914. .. versionadded:: 1.6.0
  1915. References
  1916. ----------
  1917. .. [1] Jacob R Gardner, Geoff Pleiss, David Bindel, Kilian
  1918. Q Weinberger, Andrew Gordon Wilson, "GPyTorch: Blackbox Matrix-Matrix
  1919. Gaussian Process Inference with GPU Acceleration" with contributions
  1920. from Max Balandat and Ruihan Wu. Available online:
  1921. https://github.com/cornellius-gp/gpytorch
  1922. .. [2] J. Demmel, P. Koev, and X. Li, "A Brief Survey of Direct Linear
  1923. Solvers". In Z. Bai, J. Demmel, J. Dongarra, A. Ruhe, and H. van der
  1924. Vorst, editors. Templates for the Solution of Algebraic Eigenvalue
  1925. Problems: A Practical Guide. SIAM, Philadelphia, 2000. Available at:
  1926. http://www.netlib.org/utk/people/JackDongarra/etemplates/node384.html
  1927. .. [3] R. Scheibler, E. Bezzam, I. Dokmanic, Pyroomacoustics: A Python
  1928. package for audio room simulations and array processing algorithms,
  1929. Proc. IEEE ICASSP, Calgary, CA, 2018.
  1930. https://github.com/LCAV/pyroomacoustics/blob/pypi-release/
  1931. pyroomacoustics/adaptive/util.py
  1932. .. [4] Marano S, Edwards B, Ferrari G and Fah D (2017), "Fitting
  1933. Earthquake Spectra: Colored Noise and Incomplete Data", Bulletin of
  1934. the Seismological Society of America., January, 2017. Vol. 107(1),
  1935. pp. 276-291.
  1936. Examples
  1937. --------
  1938. Multiply the Toeplitz matrix T with matrix x::
  1939. [ 1 -1 -2 -3] [1 10]
  1940. T = [ 3 1 -1 -2] x = [2 11]
  1941. [ 6 3 1 -1] [2 11]
  1942. [10 6 3 1] [5 19]
  1943. To specify the Toeplitz matrix, only the first column and the first
  1944. row are needed.
  1945. >>> import numpy as np
  1946. >>> c = np.array([1, 3, 6, 10]) # First column of T
  1947. >>> r = np.array([1, -1, -2, -3]) # First row of T
  1948. >>> x = np.array([[1, 10], [2, 11], [2, 11], [5, 19]])
  1949. >>> from scipy.linalg import toeplitz, matmul_toeplitz
  1950. >>> matmul_toeplitz((c, r), x)
  1951. array([[-20., -80.],
  1952. [ -7., -8.],
  1953. [ 9., 85.],
  1954. [ 33., 218.]])
  1955. Check the result by creating the full Toeplitz matrix and
  1956. multiplying it by ``x``.
  1957. >>> toeplitz(c, r) @ x
  1958. array([[-20, -80],
  1959. [ -7, -8],
  1960. [ 9, 85],
  1961. [ 33, 218]])
  1962. The full matrix is never formed explicitly, so this routine
  1963. is suitable for very large Toeplitz matrices.
  1964. >>> n = 1000000
  1965. >>> matmul_toeplitz([1] + [0]*(n-1), np.ones(n))
  1966. array([1., 1., 1., ..., 1., 1., 1.], shape=(1000000,))
  1967. """
  1968. from ..fft import fft, ifft, rfft, irfft
  1969. c, r = c_or_cr if isinstance(c_or_cr, tuple) else (c_or_cr, np.conjugate(c_or_cr))
  1970. return _matmul_toepltiz(r, c, x, workers, check_finite, fft, ifft, rfft, irfft)
  1971. @_apply_over_batch(('r', 1), ('c', 1), ('x', '1|2'))
  1972. def _matmul_toepltiz(r, c, x, workers, check_finite, fft, ifft, rfft, irfft):
  1973. r, c, x, dtype, x_shape = _validate_args_for_toeplitz_ops((c, r), x, check_finite,
  1974. keep_b_shape=False,
  1975. enforce_square=False)
  1976. n, m = x.shape
  1977. T_nrows = len(c)
  1978. T_ncols = len(r)
  1979. p = T_nrows + T_ncols - 1 # equivalent to len(embedded_col)
  1980. return_shape = (T_nrows,) if len(x_shape) == 1 else (T_nrows, m)
  1981. # accommodate empty arrays
  1982. if x.size == 0:
  1983. return np.empty_like(x, shape=return_shape)
  1984. embedded_col = np.concatenate((c, r[-1:0:-1]))
  1985. if np.iscomplexobj(embedded_col) or np.iscomplexobj(x):
  1986. fft_mat = fft(embedded_col, axis=0, workers=workers).reshape(-1, 1)
  1987. fft_x = fft(x, n=p, axis=0, workers=workers)
  1988. mat_times_x = ifft(fft_mat*fft_x, axis=0,
  1989. workers=workers)[:T_nrows, :]
  1990. else:
  1991. # Real inputs; using rfft is faster
  1992. fft_mat = rfft(embedded_col, axis=0, workers=workers).reshape(-1, 1)
  1993. fft_x = rfft(x, n=p, axis=0, workers=workers)
  1994. mat_times_x = irfft(fft_mat*fft_x, axis=0,
  1995. workers=workers, n=p)[:T_nrows, :]
  1996. return mat_times_x.reshape(*return_shape)