_decomp_qr.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. """QR decomposition functions."""
  2. import numpy as np
  3. from scipy._lib._util import _apply_over_batch
  4. # Local imports
  5. from .lapack import get_lapack_funcs
  6. from ._misc import _datacopied
# Public API of this module (names exported by ``from ... import *``).
__all__ = ['qr', 'qr_multiply', 'rq']
  8. def safecall(f, name, *args, **kwargs):
  9. """Call a LAPACK routine, determining lwork automatically and handling
  10. error return values"""
  11. lwork = kwargs.get("lwork", None)
  12. if lwork in (None, -1):
  13. kwargs['lwork'] = -1
  14. ret = f(*args, **kwargs)
  15. kwargs['lwork'] = ret[-2][0].real.astype(np.int_)
  16. ret = f(*args, **kwargs)
  17. if ret[-1] < 0:
  18. raise ValueError(f"illegal value in {-ret[-1]}th argument of internal {name}")
  19. return ret[:-2]
  20. @_apply_over_batch(('a', 2))
  21. def qr(a, overwrite_a=False, lwork=None, mode='full', pivoting=False,
  22. check_finite=True):
  23. """
  24. Compute QR decomposition of a matrix.
  25. Calculate the decomposition ``A = Q R`` where Q is unitary/orthogonal
  26. and R upper triangular.
  27. Parameters
  28. ----------
  29. a : (M, N) array_like
  30. Matrix to be decomposed
  31. overwrite_a : bool, optional
  32. Whether data in `a` is overwritten (may improve performance if
  33. `overwrite_a` is set to True by reusing the existing input data
  34. structure rather than creating a new one.)
  35. lwork : int, optional
  36. Work array size, lwork >= a.shape[1]. If None or -1, an optimal size
  37. is computed.
  38. mode : {'full', 'r', 'economic', 'raw'}, optional
  39. Determines what information is to be returned: either both Q and R
  40. ('full', default), only R ('r') or both Q and R but computed in
  41. economy-size ('economic', see Notes). The final option 'raw'
  42. (added in SciPy 0.11) makes the function return two matrices
  43. (Q, TAU) in the internal format used by LAPACK.
  44. pivoting : bool, optional
  45. Whether or not factorization should include pivoting for rank-revealing
  46. qr decomposition. If pivoting, compute the decomposition
  47. ``A[:, P] = Q @ R`` as above, but where P is chosen such that the
  48. diagonal of R is non-increasing. Equivalently, albeit less efficiently,
  49. an explicit P matrix may be formed explicitly by permuting the rows or columns
  50. (depending on the side of the equation on which it is to be used) of
  51. an identity matrix. See Examples.
  52. check_finite : bool, optional
  53. Whether to check that the input matrix contains only finite numbers.
  54. Disabling may give a performance gain, but may result in problems
  55. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  56. Returns
  57. -------
  58. Q : float or complex ndarray
  59. Of shape (M, M), or (M, K) for ``mode='economic'``. Not returned
  60. if ``mode='r'``. Replaced by tuple ``(Q, TAU)`` if ``mode='raw'``.
  61. R : float or complex ndarray
  62. Of shape (M, N), or (K, N) for ``mode in ['economic', 'raw']``.
  63. ``K = min(M, N)``.
  64. P : int ndarray
  65. Of shape (N,) for ``pivoting=True``. Not returned if
  66. ``pivoting=False``.
  67. Raises
  68. ------
  69. LinAlgError
  70. Raised if decomposition fails
  71. Notes
  72. -----
  73. This is an interface to the LAPACK routines dgeqrf, zgeqrf,
  74. dorgqr, zungqr, dgeqp3, and zgeqp3.
  75. If ``mode=economic``, the shapes of Q and R are (M, K) and (K, N) instead
  76. of (M,M) and (M,N), with ``K=min(M,N)``.
  77. Examples
  78. --------
  79. >>> import numpy as np
  80. >>> from scipy import linalg
  81. >>> rng = np.random.default_rng()
  82. >>> a = rng.standard_normal((9, 6))
  83. >>> q, r = linalg.qr(a)
  84. >>> np.allclose(a, np.dot(q, r))
  85. True
  86. >>> q.shape, r.shape
  87. ((9, 9), (9, 6))
  88. >>> r2 = linalg.qr(a, mode='r')
  89. >>> np.allclose(r, r2)
  90. True
  91. >>> q3, r3 = linalg.qr(a, mode='economic')
  92. >>> q3.shape, r3.shape
  93. ((9, 6), (6, 6))
  94. >>> q4, r4, p4 = linalg.qr(a, pivoting=True)
  95. >>> d = np.abs(np.diag(r4))
  96. >>> np.all(d[1:] <= d[:-1])
  97. True
  98. >>> np.allclose(a[:, p4], np.dot(q4, r4))
  99. True
  100. >>> P = np.eye(p4.size)[p4]
  101. >>> np.allclose(a, np.dot(q4, r4) @ P)
  102. True
  103. >>> np.allclose(a @ P.T, np.dot(q4, r4))
  104. True
  105. >>> q4.shape, r4.shape, p4.shape
  106. ((9, 9), (9, 6), (6,))
  107. >>> q5, r5, p5 = linalg.qr(a, mode='economic', pivoting=True)
  108. >>> q5.shape, r5.shape, p5.shape
  109. ((9, 6), (6, 6), (6,))
  110. >>> P = np.eye(6)[:, p5]
  111. >>> np.allclose(a @ P, np.dot(q5, r5))
  112. True
  113. """
  114. # 'qr' was the old default, equivalent to 'full'. Neither 'full' nor
  115. # 'qr' are used below.
  116. # 'raw' is used internally by qr_multiply
  117. if mode not in ['full', 'qr', 'r', 'economic', 'raw']:
  118. raise ValueError("Mode argument should be one of ['full', 'r', "
  119. "'economic', 'raw']")
  120. if check_finite:
  121. a1 = np.asarray_chkfinite(a)
  122. else:
  123. a1 = np.asarray(a)
  124. if len(a1.shape) != 2:
  125. raise ValueError("expected a 2-D array")
  126. M, N = a1.shape
  127. # accommodate empty arrays
  128. if a1.size == 0:
  129. K = min(M, N)
  130. if mode not in ['economic', 'raw']:
  131. Q = np.empty_like(a1, shape=(M, M))
  132. Q[...] = np.identity(M)
  133. R = np.empty_like(a1)
  134. else:
  135. Q = np.empty_like(a1, shape=(M, K))
  136. R = np.empty_like(a1, shape=(K, N))
  137. if pivoting:
  138. Rj = R, np.arange(N, dtype=np.int32)
  139. else:
  140. Rj = R,
  141. if mode == 'r':
  142. return Rj
  143. elif mode == 'raw':
  144. qr = np.empty_like(a1, shape=(M, N))
  145. tau = np.zeros_like(a1, shape=(K,))
  146. return ((qr, tau),) + Rj
  147. return (Q,) + Rj
  148. overwrite_a = overwrite_a or (_datacopied(a1, a))
  149. if pivoting:
  150. geqp3, = get_lapack_funcs(('geqp3',), (a1,))
  151. qr, jpvt, tau = safecall(geqp3, "geqp3", a1, overwrite_a=overwrite_a)
  152. jpvt -= 1 # geqp3 returns a 1-based index array, so subtract 1
  153. else:
  154. geqrf, = get_lapack_funcs(('geqrf',), (a1,))
  155. qr, tau = safecall(geqrf, "geqrf", a1, lwork=lwork,
  156. overwrite_a=overwrite_a)
  157. if mode not in ['economic', 'raw'] or M < N:
  158. R = np.triu(qr)
  159. else:
  160. R = np.triu(qr[:N, :])
  161. if pivoting:
  162. Rj = R, jpvt
  163. else:
  164. Rj = R,
  165. if mode == 'r':
  166. return Rj
  167. elif mode == 'raw':
  168. return ((qr, tau),) + Rj
  169. gor_un_gqr, = get_lapack_funcs(('orgqr',), (qr,))
  170. if M < N:
  171. Q, = safecall(gor_un_gqr, "gorgqr/gungqr", qr[:, :M], tau,
  172. lwork=lwork, overwrite_a=1)
  173. elif mode == 'economic':
  174. Q, = safecall(gor_un_gqr, "gorgqr/gungqr", qr, tau, lwork=lwork,
  175. overwrite_a=1)
  176. else:
  177. t = qr.dtype.char
  178. qqr = np.empty((M, M), dtype=t)
  179. qqr[:, :N] = qr
  180. Q, = safecall(gor_un_gqr, "gorgqr/gungqr", qqr, tau, lwork=lwork,
  181. overwrite_a=1)
  182. return (Q,) + Rj
  183. @_apply_over_batch(('a', 2), ('c', '1|2'))
  184. def qr_multiply(a, c, mode='right', pivoting=False, conjugate=False,
  185. overwrite_a=False, overwrite_c=False):
  186. """
  187. Calculate the QR decomposition and multiply Q with a matrix.
  188. Calculate the decomposition ``A = Q R`` where Q is unitary/orthogonal
  189. and R upper triangular. Multiply Q with a vector or a matrix c.
  190. Parameters
  191. ----------
  192. a : (M, N), array_like
  193. Input array
  194. c : array_like
  195. Input array to be multiplied by ``q``.
  196. mode : {'left', 'right'}, optional
  197. ``Q @ c`` is returned if mode is 'left', ``c @ Q`` is returned if
  198. mode is 'right'.
  199. The shape of c must be appropriate for the matrix multiplications,
  200. if mode is 'left', ``min(a.shape) == c.shape[0]``,
  201. if mode is 'right', ``a.shape[0] == c.shape[1]``.
  202. pivoting : bool, optional
  203. Whether or not factorization should include pivoting for rank-revealing
  204. qr decomposition, see the documentation of qr.
  205. conjugate : bool, optional
  206. Whether Q should be complex-conjugated. This might be faster
  207. than explicit conjugation.
  208. overwrite_a : bool, optional
  209. Whether data in a is overwritten (may improve performance)
  210. overwrite_c : bool, optional
  211. Whether data in c is overwritten (may improve performance).
  212. If this is used, c must be big enough to keep the result,
  213. i.e. ``c.shape[0]`` = ``a.shape[0]`` if mode is 'left'.
  214. Returns
  215. -------
  216. CQ : ndarray
  217. The product of ``Q`` and ``c``.
  218. R : (K, N), ndarray
  219. R array of the resulting QR factorization where ``K = min(M, N)``.
  220. P : (N,) ndarray
  221. Integer pivot array. Only returned when ``pivoting=True``.
  222. Raises
  223. ------
  224. LinAlgError
  225. Raised if QR decomposition fails.
  226. Notes
  227. -----
  228. This is an interface to the LAPACK routines ``?GEQRF``, ``?ORMQR``,
  229. ``?UNMQR``, and ``?GEQP3``.
  230. .. versionadded:: 0.11.0
  231. Examples
  232. --------
  233. >>> import numpy as np
  234. >>> from scipy.linalg import qr_multiply, qr
  235. >>> A = np.array([[1, 3, 3], [2, 3, 2], [2, 3, 3], [1, 3, 2]])
  236. >>> qc, r1, piv1 = qr_multiply(A, 2*np.eye(4), pivoting=1)
  237. >>> qc
  238. array([[-1., 1., -1.],
  239. [-1., -1., 1.],
  240. [-1., -1., -1.],
  241. [-1., 1., 1.]])
  242. >>> r1
  243. array([[-6., -3., -5. ],
  244. [ 0., -1., -1.11022302e-16],
  245. [ 0., 0., -1. ]])
  246. >>> piv1
  247. array([1, 0, 2], dtype=int32)
  248. >>> q2, r2, piv2 = qr(A, mode='economic', pivoting=1)
  249. >>> np.allclose(2*q2 - qc, np.zeros((4, 3)))
  250. True
  251. """
  252. if mode not in ['left', 'right']:
  253. raise ValueError("Mode argument can only be 'left' or 'right' but "
  254. f"not '{mode}'")
  255. c = np.asarray_chkfinite(c)
  256. if c.ndim < 2:
  257. onedim = True
  258. c = np.atleast_2d(c)
  259. if mode == "left":
  260. c = c.T
  261. else:
  262. onedim = False
  263. a = np.atleast_2d(np.asarray(a)) # chkfinite done in qr
  264. M, N = a.shape
  265. if mode == 'left':
  266. if c.shape[0] != min(M, N + overwrite_c*(M-N)):
  267. raise ValueError('Array shapes are not compatible for Q @ c'
  268. f' operation: {a.shape} vs {c.shape}')
  269. else:
  270. if M != c.shape[1]:
  271. raise ValueError('Array shapes are not compatible for c @ Q'
  272. f' operation: {c.shape} vs {a.shape}')
  273. raw = qr(a, overwrite_a, None, "raw", pivoting)
  274. Q, tau = raw[0]
  275. # accommodate empty arrays
  276. if c.size == 0:
  277. return (np.empty_like(c),) + raw[1:]
  278. gor_un_mqr, = get_lapack_funcs(('ormqr',), (Q,))
  279. if gor_un_mqr.typecode in ('s', 'd'):
  280. trans = "T"
  281. else:
  282. trans = "C"
  283. Q = Q[:, :min(M, N)]
  284. if M > N and mode == "left" and not overwrite_c:
  285. if conjugate:
  286. cc = np.zeros((c.shape[1], M), dtype=c.dtype, order="F")
  287. cc[:, :N] = c.T
  288. else:
  289. cc = np.zeros((M, c.shape[1]), dtype=c.dtype, order="F")
  290. cc[:N, :] = c
  291. trans = "N"
  292. if conjugate:
  293. lr = "R"
  294. else:
  295. lr = "L"
  296. overwrite_c = True
  297. elif c.flags["C_CONTIGUOUS"] and trans == "T" or conjugate:
  298. cc = c.T
  299. if mode == "left":
  300. lr = "R"
  301. else:
  302. lr = "L"
  303. else:
  304. trans = "N"
  305. cc = c
  306. if mode == "left":
  307. lr = "L"
  308. else:
  309. lr = "R"
  310. cQ, = safecall(gor_un_mqr, "gormqr/gunmqr", lr, trans, Q, tau, cc,
  311. overwrite_c=overwrite_c)
  312. if trans != "N":
  313. cQ = cQ.T
  314. if mode == "right":
  315. cQ = cQ[:, :min(M, N)]
  316. if onedim:
  317. cQ = cQ.ravel()
  318. return (cQ,) + raw[1:]
  319. @_apply_over_batch(('a', 2))
  320. def rq(a, overwrite_a=False, lwork=None, mode='full', check_finite=True):
  321. """
  322. Compute RQ decomposition of a matrix.
  323. Calculate the decomposition ``A = R Q`` where Q is unitary/orthogonal
  324. and R upper triangular.
  325. Parameters
  326. ----------
  327. a : (M, N) array_like
  328. Matrix to be decomposed
  329. overwrite_a : bool, optional
  330. Whether data in a is overwritten (may improve performance)
  331. lwork : int, optional
  332. Work array size, lwork >= a.shape[1]. If None or -1, an optimal size
  333. is computed.
  334. mode : {'full', 'r', 'economic'}, optional
  335. Determines what information is to be returned: either both Q and R
  336. ('full', default), only R ('r') or both Q and R but computed in
  337. economy-size ('economic', see Notes).
  338. check_finite : bool, optional
  339. Whether to check that the input matrix contains only finite numbers.
  340. Disabling may give a performance gain, but may result in problems
  341. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  342. Returns
  343. -------
  344. R : float or complex ndarray
  345. Of shape (M, N) or (M, K) for ``mode='economic'``. ``K = min(M, N)``.
  346. Q : float or complex ndarray
  347. Of shape (N, N) or (K, N) for ``mode='economic'``. Not returned
  348. if ``mode='r'``.
  349. Raises
  350. ------
  351. LinAlgError
  352. If decomposition fails.
  353. Notes
  354. -----
  355. This is an interface to the LAPACK routines sgerqf, dgerqf, cgerqf, zgerqf,
  356. sorgrq, dorgrq, cungrq and zungrq.
  357. If ``mode=economic``, the shapes of Q and R are (K, N) and (M, K) instead
  358. of (N,N) and (M,N), with ``K=min(M,N)``.
  359. Examples
  360. --------
  361. >>> import numpy as np
  362. >>> from scipy import linalg
  363. >>> rng = np.random.default_rng()
  364. >>> a = rng.standard_normal((6, 9))
  365. >>> r, q = linalg.rq(a)
  366. >>> np.allclose(a, r @ q)
  367. True
  368. >>> r.shape, q.shape
  369. ((6, 9), (9, 9))
  370. >>> r2 = linalg.rq(a, mode='r')
  371. >>> np.allclose(r, r2)
  372. True
  373. >>> r3, q3 = linalg.rq(a, mode='economic')
  374. >>> r3.shape, q3.shape
  375. ((6, 6), (6, 9))
  376. """
  377. if mode not in ['full', 'r', 'economic']:
  378. raise ValueError(
  379. "Mode argument should be one of ['full', 'r', 'economic']")
  380. if check_finite:
  381. a1 = np.asarray_chkfinite(a)
  382. else:
  383. a1 = np.asarray(a)
  384. if len(a1.shape) != 2:
  385. raise ValueError('expected matrix')
  386. M, N = a1.shape
  387. # accommodate empty arrays
  388. if a1.size == 0:
  389. K = min(M, N)
  390. if not mode == 'economic':
  391. R = np.empty_like(a1)
  392. Q = np.empty_like(a1, shape=(N, N))
  393. Q[...] = np.identity(N)
  394. else:
  395. R = np.empty_like(a1, shape=(M, K))
  396. Q = np.empty_like(a1, shape=(K, N))
  397. if mode == 'r':
  398. return R
  399. return R, Q
  400. overwrite_a = overwrite_a or (_datacopied(a1, a))
  401. gerqf, = get_lapack_funcs(('gerqf',), (a1,))
  402. rq, tau = safecall(gerqf, 'gerqf', a1, lwork=lwork,
  403. overwrite_a=overwrite_a)
  404. if not mode == 'economic' or N < M:
  405. R = np.triu(rq, N-M)
  406. else:
  407. R = np.triu(rq[-M:, -M:])
  408. if mode == 'r':
  409. return R
  410. gor_un_grq, = get_lapack_funcs(('orgrq',), (rq,))
  411. if N < M:
  412. Q, = safecall(gor_un_grq, "gorgrq/gungrq", rq[-N:], tau, lwork=lwork,
  413. overwrite_a=1)
  414. elif mode == 'economic':
  415. Q, = safecall(gor_un_grq, "gorgrq/gungrq", rq, tau, lwork=lwork,
  416. overwrite_a=1)
  417. else:
  418. rq1 = np.empty((N, N), dtype=rq.dtype)
  419. rq1[-M:] = rq
  420. Q, = safecall(gor_un_grq, "gorgrq/gungrq", rq1, tau, lwork=lwork,
  421. overwrite_a=1)
  422. return R, Q