test_batch.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620
  1. import inspect
  2. import pytest
  3. import numpy as np
  4. from numpy.testing import assert_allclose
  5. from scipy import linalg, sparse
  6. real_floating = [np.float32, np.float64]
  7. complex_floating = [np.complex64, np.complex128]
  8. floating = real_floating + complex_floating
  9. def get_random(shape, *, dtype, rng):
  10. A = rng.random(shape)
  11. if np.issubdtype(dtype, np.complexfloating):
  12. A = A + rng.random(shape) * 1j
  13. return A.astype(dtype)
  14. def get_nearly_hermitian(shape, dtype, atol, rng):
  15. # Generate a batch of nearly Hermitian matrices with specified
  16. # `shape` and `dtype`. `atol` controls the level of noise in
  17. # Hermitian-ness to by generated by `rng`.
  18. A = rng.random(shape).astype(dtype)
  19. At = np.conj(A.swapaxes(-1, -2))
  20. noise = rng.standard_normal(size=A.shape).astype(dtype) * atol
  21. return A + At + noise
  22. class TestBatch:
  23. # Test batch support for most linalg functions
  24. def batch_test(self, fun, arrays, *, core_dim=2, n_out=1, kwargs=None, dtype=None,
  25. broadcast=True, check_kwargs=True):
  26. # Check that all outputs of batched call `fun(A, **kwargs)` are the same
  27. # as if we loop over the separate vectors/matrices in `A`. Also check
  28. # that `fun` accepts `A` by position or keyword and that results are
  29. # identical. This is important because the name of the array argument
  30. # is manually specified to the decorator, and it's easy to mess up.
  31. # However, this makes it hard to test positional arguments passed
  32. # after the array, so we test that separately for a few functions to
  33. # make sure the decorator is working as it should.
  34. kwargs = {} if kwargs is None else kwargs
  35. parameters = list(inspect.signature(fun).parameters.keys())
  36. arrays = (arrays,) if not isinstance(arrays, tuple) else arrays
  37. # Identical results when passing argument by keyword or position
  38. res2 = fun(*arrays, **kwargs)
  39. if check_kwargs:
  40. res1 = fun(**dict(zip(parameters, arrays)), **kwargs)
  41. for out1, out2 in zip(res1, res2): # even a single array is iterable...
  42. np.testing.assert_equal(out1, out2)
  43. # Check results vs looping over
  44. res = (res2,) if n_out == 1 else res2
  45. # This is not the general behavior (only batch dimensions get
  46. # broadcasted by the decorator) but it's easier for testing.
  47. if broadcast:
  48. arrays = np.broadcast_arrays(*arrays)
  49. batch_shape = arrays[0].shape[:-core_dim]
  50. for i in range(batch_shape[0]):
  51. for j in range(batch_shape[1]):
  52. arrays_ij = (array[i, j] for array in arrays)
  53. ref = fun(*arrays_ij, **kwargs)
  54. ref = ((np.asarray(ref),) if n_out == 1 else
  55. tuple(np.asarray(refk) for refk in ref))
  56. for k in range(n_out):
  57. assert_allclose(res[k][i, j], ref[k])
  58. assert np.shape(res[k][i, j]) == ref[k].shape
  59. for k in range(len(ref)):
  60. out_dtype = ref[k].dtype if dtype is None else dtype
  61. assert res[k].dtype == out_dtype
  62. return res2 # return original, non-tuplized result
  63. @pytest.mark.parametrize('dtype', floating)
  64. def test_expm_cond(self, dtype):
  65. rng = np.random.default_rng(8342310302941288912051)
  66. A = rng.random((5, 3, 4, 4)).astype(dtype)
  67. self.batch_test(linalg.expm_cond, A)
  68. @pytest.mark.parametrize('dtype', floating)
  69. def test_issymmetric(self, dtype):
  70. rng = np.random.default_rng(8342310302941288912051)
  71. A = get_nearly_hermitian((5, 3, 4, 4), dtype, 3e-4, rng)
  72. res = self.batch_test(linalg.issymmetric, A, kwargs=dict(atol=1e-3))
  73. assert not np.all(res) # ensure test is not trivial: not all True or False;
  74. assert np.any(res) # also confirms that `atol` is passed to issymmetric
  75. @pytest.mark.parametrize('dtype', floating)
  76. def test_ishermitian(self, dtype):
  77. rng = np.random.default_rng(8342310302941288912051)
  78. A = get_nearly_hermitian((5, 3, 4, 4), dtype, 3e-4, rng)
  79. res = self.batch_test(linalg.ishermitian, A, kwargs=dict(atol=1e-3))
  80. assert not np.all(res) # ensure test is not trivial: not all True or False;
  81. assert np.any(res) # also confirms that `atol` is passed to ishermitian
  82. @pytest.mark.parametrize('dtype', floating)
  83. def test_diagsvd(self, dtype):
  84. rng = np.random.default_rng(8342310302941288912051)
  85. A = rng.random((5, 3, 4)).astype(dtype)
  86. res1 = self.batch_test(linalg.diagsvd, A, kwargs=dict(M=6, N=4), core_dim=1)
  87. # test that `M, N` can be passed by position
  88. res2 = linalg.diagsvd(A, 6, 4)
  89. np.testing.assert_equal(res1, res2)
  90. @pytest.mark.parametrize('fun', [linalg.inv, linalg.sqrtm, linalg.signm,
  91. linalg.sinm, linalg.cosm, linalg.tanhm,
  92. linalg.sinhm, linalg.coshm, linalg.tanhm,
  93. linalg.pinv, linalg.pinvh, linalg.orth])
  94. @pytest.mark.parametrize('dtype', floating)
  95. def test_matmat(self, fun, dtype): # matrix in, matrix out
  96. rng = np.random.default_rng(8342310302941288912051)
  97. A = get_random((5, 3, 4, 4), dtype=dtype, rng=rng)
  98. # sqrtm can return complex output for real input resulting in i/o type
  99. # mismatch. Nudge the eigenvalues to positive side to avoid this.
  100. if fun == linalg.sqrtm:
  101. A = A + 3*np.eye(4, dtype=dtype)
  102. self.batch_test(fun, A)
  103. @pytest.mark.parametrize('dtype', floating)
  104. def test_null_space(self, dtype):
  105. rng = np.random.default_rng(8342310302941288912051)
  106. A = get_random((5, 3, 4, 6), dtype=dtype, rng=rng)
  107. self.batch_test(linalg.null_space, A)
  108. @pytest.mark.parametrize('dtype', floating)
  109. def test_funm(self, dtype):
  110. rng = np.random.default_rng(8342310302941288912051)
  111. A = get_random((2, 4, 3, 3), dtype=dtype, rng=rng)
  112. self.batch_test(linalg.funm, A, kwargs=dict(func=np.sin))
  113. @pytest.mark.parametrize('dtype', floating)
  114. def test_fractional_matrix_power(self, dtype):
  115. rng = np.random.default_rng(8342310302941288912051)
  116. A = get_random((2, 4, 3, 3), dtype=dtype, rng=rng)
  117. res1 = self.batch_test(linalg.fractional_matrix_power, A, kwargs={'t':1.5})
  118. # test that `t` can be passed by position
  119. res2 = linalg.fractional_matrix_power(A, 1.5)
  120. np.testing.assert_equal(res1, res2)
  121. @pytest.mark.parametrize('dtype', floating)
  122. def test_logm(self, dtype):
  123. # One test failed absolute tolerance with default random seed
  124. rng = np.random.default_rng(89940026998903887141749720079406074936)
  125. A = get_random((5, 3, 4, 4), dtype=dtype, rng=rng)
  126. A = A + 3*np.eye(4) # avoid complex output for real input
  127. res1 = self.batch_test(linalg.logm, A)
  128. # test that `disp` can be passed by position
  129. res2 = linalg.logm(A)
  130. for res1i, res2i in zip(res1, res2):
  131. np.testing.assert_equal(res1i, res2i)
  132. @pytest.mark.parametrize('dtype', floating)
  133. def test_pinv(self, dtype):
  134. rng = np.random.default_rng(8342310302941288912051)
  135. A = get_random((5, 3, 4, 4), dtype=dtype, rng=rng)
  136. self.batch_test(linalg.pinv, A, n_out=2, kwargs=dict(return_rank=True))
  137. @pytest.mark.parametrize('dtype', floating)
  138. def test_matrix_balance(self, dtype):
  139. rng = np.random.default_rng(8342310302941288912051)
  140. A = get_random((5, 3, 4, 4), dtype=dtype, rng=rng)
  141. self.batch_test(linalg.matrix_balance, A, n_out=2)
  142. self.batch_test(linalg.matrix_balance, A, n_out=2, kwargs={'separate':True})
  143. @pytest.mark.parametrize('dtype', floating)
  144. def test_bandwidth(self, dtype):
  145. rng = np.random.default_rng(8342310302941288912051)
  146. A = get_random((4, 4), dtype=dtype, rng=rng)
  147. A = np.asarray([np.triu(A, k) for k in range(-3, 3)]).reshape((2, 3, 4, 4))
  148. self.batch_test(linalg.bandwidth, A, n_out=2)
  149. @pytest.mark.parametrize('fun_n_out', [(linalg.cholesky, 1), (linalg.ldl, 3),
  150. (linalg.cho_factor, 2)])
  151. @pytest.mark.parametrize('dtype', floating)
  152. def test_ldl_cholesky(self, fun_n_out, dtype):
  153. rng = np.random.default_rng(8342310302941288912051)
  154. fun, n_out = fun_n_out
  155. A = get_nearly_hermitian((5, 3, 4, 4), dtype, 0, rng) # exactly Hermitian
  156. A = A + 4*np.eye(4, dtype=dtype) # ensure positive definite for Cholesky
  157. self.batch_test(fun, A, n_out=n_out)
  158. @pytest.mark.parametrize('compute_uv', [False, True])
  159. @pytest.mark.parametrize('dtype', floating)
  160. def test_svd(self, compute_uv, dtype):
  161. rng = np.random.default_rng(8342310302941288912051)
  162. A = get_random((5, 3, 2, 4), dtype=dtype, rng=rng)
  163. n_out = 3 if compute_uv else 1
  164. self.batch_test(linalg.svd, A, n_out=n_out, kwargs=dict(compute_uv=compute_uv))
  165. @pytest.mark.parametrize('fun', [linalg.polar, linalg.qr, linalg.rq])
  166. @pytest.mark.parametrize('dtype', floating)
  167. def test_polar_qr_rq(self, fun, dtype):
  168. rng = np.random.default_rng(8342310302941288912051)
  169. A = get_random((5, 3, 2, 4), dtype=dtype, rng=rng)
  170. self.batch_test(fun, A, n_out=2)
  171. @pytest.mark.parametrize('cdim', [(5,), (5, 4), (2, 3, 5, 4)])
  172. @pytest.mark.parametrize('dtype', floating)
  173. def test_qr_multiply(self, cdim, dtype):
  174. rng = np.random.default_rng(8342310302941288912051)
  175. A = get_random((2, 3, 5, 5), dtype=dtype, rng=rng)
  176. c = get_random(cdim, dtype=dtype, rng=rng)
  177. res = linalg.qr_multiply(A, c, mode='left')
  178. q, r = linalg.qr(A)
  179. ref = q @ c
  180. atol = 1e-6 if dtype in {np.float32, np.complex64} else 1e-12
  181. assert_allclose(res[0], ref, atol=atol)
  182. assert_allclose(res[1], r, atol=atol)
  183. @pytest.mark.parametrize('uvdim', [[(5,), (3,)], [(4, 5, 2), (4, 3, 2)]])
  184. @pytest.mark.parametrize('dtype', floating)
  185. def test_qr_update(self, uvdim, dtype):
  186. rng = np.random.default_rng(8342310302941288912051)
  187. udim, vdim = uvdim
  188. A = get_random((4, 5, 3), dtype=dtype, rng=rng)
  189. u = get_random(udim, dtype=dtype, rng=rng)
  190. v = get_random(vdim, dtype=dtype, rng=rng)
  191. q, r = linalg.qr(A)
  192. res = linalg.qr_update(q, r, u, v)
  193. for i in range(4):
  194. qi, ri = q[i], r[i]
  195. ui, vi = (u, v) if u.ndim == 1 else (u[i], v[i])
  196. ref_i = linalg.qr_update(qi, ri, ui, vi)
  197. assert_allclose(res[0][i], ref_i[0])
  198. assert_allclose(res[1][i], ref_i[1])
  199. @pytest.mark.parametrize('udim', [(5,), (4, 3, 5)])
  200. @pytest.mark.parametrize('kdim', [(), (4,)])
  201. @pytest.mark.parametrize('dtype', floating)
  202. def test_qr_insert(self, udim, kdim, dtype):
  203. rng = np.random.default_rng(8342310302941288912051)
  204. A = get_random((4, 5, 5), dtype=dtype, rng=rng)
  205. u = get_random(udim, dtype=dtype, rng=rng)
  206. k = rng.integers(0, 5, size=kdim)
  207. q, r = linalg.qr(A)
  208. res = linalg.qr_insert(q, r, u, k)
  209. for i in range(4):
  210. qi, ri = q[i], r[i]
  211. ki = k if k.ndim == 0 else k[i]
  212. ui = u if u.ndim == 1 else u[i]
  213. ref_i = linalg.qr_insert(qi, ri, ui, ki)
  214. assert_allclose(res[0][i], ref_i[0])
  215. assert_allclose(res[1][i], ref_i[1])
  216. @pytest.mark.parametrize('kdim', [(), (4,)])
  217. @pytest.mark.parametrize('dtype', floating)
  218. def test_qr_delete(self, kdim, dtype):
  219. rng = np.random.default_rng(8342310302941288912051)
  220. A = get_random((4, 5, 5), dtype=dtype, rng=rng)
  221. k = rng.integers(0, 4, size=kdim)
  222. q, r = linalg.qr(A)
  223. res = linalg.qr_delete(q, r, k)
  224. for i in range(4):
  225. qi, ri = q[i], r[i]
  226. ki = k if k.ndim == 0 else k[i]
  227. ref_i = linalg.qr_delete(qi, ri, ki)
  228. assert_allclose(res[0][i], ref_i[0])
  229. assert_allclose(res[1][i], ref_i[1])
  230. @pytest.mark.parametrize('fun', [linalg.schur, linalg.lu_factor])
  231. @pytest.mark.parametrize('dtype', floating)
  232. def test_schur_lu(self, fun, dtype):
  233. rng = np.random.default_rng(8342310302941288912051)
  234. A = get_random((5, 3, 4, 4), dtype=dtype, rng=rng)
  235. self.batch_test(fun, A, n_out=2)
  236. @pytest.mark.parametrize('calc_q', [False, True])
  237. @pytest.mark.parametrize('dtype', floating)
  238. def test_hessenberg(self, calc_q, dtype):
  239. rng = np.random.default_rng(8342310302941288912051)
  240. A = get_random((5, 3, 4, 4), dtype=dtype, rng=rng)
  241. n_out = 2 if calc_q else 1
  242. self.batch_test(linalg.hessenberg, A, n_out=n_out, kwargs=dict(calc_q=calc_q))
  243. @pytest.mark.parametrize('eigvals_only', [False, True])
  244. @pytest.mark.parametrize('dtype', floating)
  245. def test_eig_banded(self, eigvals_only, dtype):
  246. rng = np.random.default_rng(8342310302941288912051)
  247. A = get_random((5, 3, 4, 4), dtype=dtype, rng=rng)
  248. n_out = 1 if eigvals_only else 2
  249. self.batch_test(linalg.eig_banded, A, n_out=n_out,
  250. kwargs=dict(eigvals_only=eigvals_only))
  251. @pytest.mark.parametrize('dtype', floating)
  252. def test_eigvals_banded(self, dtype):
  253. rng = np.random.default_rng(8342310302941288912051)
  254. A = get_random((5, 3, 4, 4), dtype=dtype, rng=rng)
  255. self.batch_test(linalg.eigvals_banded, A)
  256. @pytest.mark.parametrize('two_in', [False, True])
  257. @pytest.mark.parametrize('fun_n_nout', [(linalg.eigh, 1), (linalg.eigh, 2),
  258. (linalg.eigvalsh, 1), (linalg.eigvals, 1)])
  259. @pytest.mark.parametrize('dtype', floating)
  260. def test_eigh(self, two_in, fun_n_nout, dtype):
  261. rng = np.random.default_rng(8342310302941288912051)
  262. fun, n_out = fun_n_nout
  263. A = get_nearly_hermitian((1, 3, 4, 4), dtype, 0, rng) # exactly Hermitian
  264. B = get_nearly_hermitian((2, 1, 4, 4), dtype, 0, rng) # exactly Hermitian
  265. B = B + 4*np.eye(4).astype(dtype) # needs to be positive definite
  266. args = (A, B) if two_in else (A,)
  267. kwargs = dict(eigvals_only=True) if (n_out == 1 and fun==linalg.eigh) else {}
  268. self.batch_test(fun, args, n_out=n_out, kwargs=kwargs)
  269. @pytest.mark.parametrize('compute_expm', [False, True])
  270. @pytest.mark.parametrize('dtype', floating)
  271. def test_expm_frechet(self, compute_expm, dtype):
  272. rng = np.random.default_rng(8342310302941288912051)
  273. A = get_random((1, 3, 4, 4), dtype=dtype, rng=rng)
  274. E = get_random((2, 1, 4, 4), dtype=dtype, rng=rng)
  275. n_out = 2 if compute_expm else 1
  276. self.batch_test(linalg.expm_frechet, (A, E), n_out=n_out,
  277. kwargs=dict(compute_expm=compute_expm))
  278. @pytest.mark.parametrize('dtype', floating)
  279. def test_subspace_angles(self, dtype):
  280. rng = np.random.default_rng(8342310302941288912051)
  281. A = get_random((1, 3, 4, 3), dtype=dtype, rng=rng)
  282. B = get_random((2, 1, 4, 3), dtype=dtype, rng=rng)
  283. self.batch_test(linalg.subspace_angles, (A, B))
  284. # just to show that A and B don't need to be broadcastable
  285. M, N, K = 4, 5, 3
  286. A = get_random((1, 3, M, N), dtype=dtype, rng=rng)
  287. B = get_random((2, 1, M, K), dtype=dtype, rng=rng)
  288. assert linalg.subspace_angles(A, B).shape == (2, 3, min(N, K))
  289. @pytest.mark.parametrize('fun', [linalg.svdvals])
  290. @pytest.mark.parametrize('dtype', floating)
  291. def test_svdvals(self, fun, dtype):
  292. rng = np.random.default_rng(8342310302941288912051)
  293. A = get_random((2, 3, 4, 5), dtype=dtype, rng=rng)
  294. self.batch_test(fun, A)
  295. @pytest.mark.parametrize('fun_n_out', [(linalg.orthogonal_procrustes, 2),
  296. (linalg.khatri_rao, 1),
  297. (linalg.solve_continuous_lyapunov, 1),
  298. (linalg.solve_discrete_lyapunov, 1),
  299. (linalg.qz, 4),
  300. (linalg.ordqz, 6)])
  301. @pytest.mark.parametrize('dtype', floating)
  302. def test_two_generic_matrix_inputs(self, fun_n_out, dtype):
  303. rng = np.random.default_rng(8342310302941288912051)
  304. fun, n_out = fun_n_out
  305. A = get_random((2, 3, 4, 4), dtype=dtype, rng=rng)
  306. B = get_random((2, 3, 4, 4), dtype=dtype, rng=rng)
  307. self.batch_test(fun, (A, B), n_out=n_out)
  308. @pytest.mark.parametrize('dtype', floating)
  309. def test_cossin(self, dtype):
  310. rng = np.random.default_rng(8342310302941288912051)
  311. p, q = 3, 4
  312. X = get_random((2, 3, 10, 10), dtype=dtype, rng=rng)
  313. x11, x12, x21, x22 = (X[..., :p, :q], X[..., :p, q:],
  314. X[..., p:, :q], X[..., p:, q:])
  315. res = linalg.cossin(X, p, q)
  316. ref = linalg.cossin((x11, x12, x21, x22))
  317. for res_i, ref_i in zip(res, ref):
  318. np.testing.assert_equal(res_i, ref_i)
  319. for j in range(2):
  320. for k in range(3):
  321. ref_jk = linalg.cossin(X[j, k], p, q)
  322. for res_i, ref_ijk in zip(res, ref_jk):
  323. np.testing.assert_equal(res_i[j, k], ref_ijk)
  324. @pytest.mark.parametrize('dtype', floating)
  325. def test_sylvester(self, dtype):
  326. rng = np.random.default_rng(8342310302941288912051)
  327. A = get_random((2, 3, 5, 5), dtype=dtype, rng=rng)
  328. B = get_random((2, 3, 5, 5), dtype=dtype, rng=rng)
  329. C = get_random((2, 3, 5, 5), dtype=dtype, rng=rng)
  330. self.batch_test(linalg.solve_sylvester, (A, B, C))
  331. @pytest.mark.parametrize('fun', [linalg.solve_continuous_are,
  332. linalg.solve_discrete_are])
  333. @pytest.mark.parametrize('dtype', floating)
  334. def test_are(self, fun, dtype):
  335. rng = np.random.default_rng(8342310302941288912051)
  336. a = get_random((2, 3, 5, 5), dtype=dtype, rng=rng)
  337. b = get_random((2, 3, 5, 5), dtype=dtype, rng=rng)
  338. q = get_nearly_hermitian((2, 3, 5, 5), dtype=dtype, atol=0, rng=rng)
  339. r = get_nearly_hermitian((2, 3, 5, 5), dtype=dtype, atol=0, rng=rng)
  340. a = a + 5*np.eye(5) # making these positive definite seems to help
  341. b = b + 5*np.eye(5)
  342. q = q + 5*np.eye(5)
  343. r = r + 5*np.eye(5)
  344. e = np.eye(5)
  345. s = np.zeros((5, 5))
  346. self.batch_test(fun, (a, b, q, r))
  347. self.batch_test(fun, (a, b, q, r, e))
  348. self.batch_test(fun, (a, b, q, r, e, s))
  349. res = fun(a, b, q, r)
  350. ref = fun(a, b, q, r, s=s)
  351. np.testing.assert_allclose(res, ref)
  352. @pytest.mark.parametrize('dtype', floating)
  353. def test_rsf2cs(self, dtype):
  354. rng = np.random.default_rng(8342310302941288912051)
  355. A = get_random((2, 3, 4, 4), dtype=dtype, rng=rng)
  356. T, Z = linalg.schur(A)
  357. self.batch_test(linalg.rsf2csf, (T, Z), n_out=2)
  358. @pytest.mark.parametrize('dtype', floating)
  359. def test_cholesky_banded(self, dtype):
  360. rng = np.random.default_rng(8342310302941288912051)
  361. ab = get_random((5, 4, 3, 6), dtype=dtype, rng=rng)
  362. ab[..., -1, :] = 10 # make diagonal dominant
  363. self.batch_test(linalg.cholesky_banded, ab)
  364. @pytest.mark.parametrize('dtype', floating)
  365. def test_block_diag(self, dtype):
  366. rng = np.random.default_rng(8342310302941288912051)
  367. a = get_random((1, 3, 1, 3), dtype=dtype, rng=rng)
  368. b = get_random((2, 1, 3, 6), dtype=dtype, rng=rng)
  369. c = get_random((1, 1, 3, 2), dtype=dtype, rng=rng)
  370. # batch_test doesn't have the logic to broadcast just the batch shapes,
  371. # so do it manually.
  372. a2 = np.broadcast_to(a, (2, 3, 1, 3))
  373. b2 = np.broadcast_to(b, (2, 3, 3, 6))
  374. c2 = np.broadcast_to(c, (2, 3, 3, 2))
  375. ref = self.batch_test(linalg.block_diag, (a2, b2, c2),
  376. check_kwargs=False, broadcast=False)
  377. # Check that `block_diag` broadcasts the batch shapes as expected.
  378. res = linalg.block_diag(a, b, c)
  379. assert_allclose(res, ref)
  380. @pytest.mark.parametrize('fun_n_out', [(linalg.eigh_tridiagonal, 2),
  381. (linalg.eigvalsh_tridiagonal, 1)])
  382. @pytest.mark.parametrize('dtype', real_floating)
  383. # "Only real arrays currently supported"
  384. def test_eigh_tridiagonal(self, fun_n_out, dtype):
  385. rng = np.random.default_rng(8342310302941288912051)
  386. fun, n_out = fun_n_out
  387. d = get_random((3, 4, 5), dtype=dtype, rng=rng)
  388. e = get_random((3, 4, 4), dtype=dtype, rng=rng)
  389. self.batch_test(fun, (d, e), core_dim=1, n_out=n_out, broadcast=False)
  390. @pytest.mark.parametrize('bdim', [(5,), (5, 4), (2, 3, 5, 4)])
  391. @pytest.mark.parametrize('dtype', floating)
  392. def test_solve(self, bdim, dtype):
  393. rng = np.random.default_rng(8342310302941288912051)
  394. A = get_random((2, 3, 5, 5), dtype=dtype, rng=rng)
  395. b = get_random(bdim, dtype=dtype, rng=rng)
  396. x = linalg.solve(A, b)
  397. if len(bdim) == 1:
  398. x = x[..., np.newaxis]
  399. b = b[..., np.newaxis]
  400. assert_allclose(A @ x - b, 0, atol=2e-6)
  401. assert_allclose(x, np.linalg.solve(A, b), atol=3e-6)
  402. @pytest.mark.parametrize('bdim', [(5,), (5, 4), (2, 3, 5, 4)])
  403. @pytest.mark.parametrize('dtype', floating)
  404. def test_lu_solve(self, bdim, dtype):
  405. rng = np.random.default_rng(8342310302941288912051)
  406. A = get_random((2, 3, 5, 5), dtype=dtype, rng=rng)
  407. b = get_random(bdim, dtype=dtype, rng=rng)
  408. lu_and_piv = linalg.lu_factor(A)
  409. x = linalg.lu_solve(lu_and_piv, b)
  410. if len(bdim) == 1:
  411. x = x[..., np.newaxis]
  412. b = b[..., np.newaxis]
  413. assert_allclose(A @ x - b, 0, atol=2e-6)
  414. assert_allclose(x, np.linalg.solve(A, b), atol=3e-6)
  415. @pytest.mark.parametrize('l_and_u', [(1, 1), ([2, 1, 0], [0, 1 , 2])])
  416. @pytest.mark.parametrize('bdim', [(5,), (5, 4), (2, 3, 5, 4)])
  417. @pytest.mark.parametrize('dtype', floating)
  418. def test_solve_banded(self, l_and_u, bdim, dtype):
  419. rng = np.random.default_rng(8342310302941288912051)
  420. l, u = l_and_u
  421. ab = get_random((2, 3, 3, 5), dtype=dtype, rng=rng)
  422. b = get_random(bdim, dtype=dtype, rng=rng)
  423. x = linalg.solve_banded((l, u), ab, b)
  424. for i in range(2):
  425. for j in range(3):
  426. bij = b if len(bdim) <= 2 else b[i, j]
  427. lj = l if np.ndim(l) == 0 else l[j]
  428. uj = u if np.ndim(u) == 0 else u[j]
  429. xij = linalg.solve_banded((lj, uj), ab[i, j], bij)
  430. assert_allclose(x[i, j], xij)
  431. @pytest.mark.parametrize('separate_r', [False, True])
  432. @pytest.mark.parametrize('bdim', [(5,), (5, 4), (2, 3, 5, 4)])
  433. @pytest.mark.parametrize('dtype', floating)
  434. def test_solve_toeplitz(self, separate_r, bdim, dtype):
  435. rng = np.random.default_rng(8342310302941288912051)
  436. c = get_random((2, 3, 5), dtype=dtype, rng=rng)
  437. r = get_random((2, 3, 5), dtype=dtype, rng=rng)
  438. c_or_cr = (c, r) if separate_r else c
  439. b = get_random(bdim, dtype=dtype, rng=rng)
  440. x = linalg.solve_toeplitz(c_or_cr, b)
  441. for i in range(2):
  442. for j in range(3):
  443. bij = b if len(bdim) <= 2 else b[i, j]
  444. c_or_cr_ij = (c[i, j], r[i, j]) if separate_r else c[i, j]
  445. xij = linalg.solve_toeplitz(c_or_cr_ij, bij)
  446. assert_allclose(x[i, j], xij)
  447. @pytest.mark.parametrize('separate_r', [False, True])
  448. @pytest.mark.parametrize('xdim', [(5,), (5, 4), (2, 3, 5, 4)])
  449. @pytest.mark.parametrize('dtype', floating)
  450. def test_matmul_toeplitz(self, separate_r, xdim, dtype):
  451. rng = np.random.default_rng(8342310302941288912051)
  452. c = get_random((2, 3, 5), dtype=dtype, rng=rng)
  453. r = get_random((2, 3, 5), dtype=dtype, rng=rng)
  454. c_or_cr = (c, r) if separate_r else c
  455. x = get_random(xdim, dtype=dtype, rng=rng)
  456. res = linalg.matmul_toeplitz(c_or_cr, x)
  457. if separate_r:
  458. ref = linalg.toeplitz(c, r) @ x
  459. else:
  460. ref = linalg.toeplitz(c) @ x
  461. atol = 1e-6 if dtype in {np.float32, np.complex64} else 1e-12
  462. assert_allclose(res, ref, atol=atol)
  463. @pytest.mark.parametrize('bdim', [(5,), (5, 4), (2, 3, 5, 4)])
  464. @pytest.mark.parametrize('dtype', floating)
  465. def test_cho_solve(self, bdim, dtype):
  466. rng = np.random.default_rng(8342310302941288912051)
  467. A = get_nearly_hermitian((2, 3, 5, 5), dtype=dtype, atol=0, rng=rng)
  468. A = A + 5*np.eye(5)
  469. c_and_lower = linalg.cho_factor(A)
  470. b = get_random(bdim, dtype=dtype, rng=rng)
  471. x = linalg.cho_solve(c_and_lower, b)
  472. if len(bdim) == 1:
  473. x = x[..., np.newaxis]
  474. b = b[..., np.newaxis]
  475. assert_allclose(A @ x - b, 0, atol=1e-6)
  476. assert_allclose(x, np.linalg.solve(A, b), atol=2e-6)
  477. @pytest.mark.parametrize('lower', [False, True])
  478. @pytest.mark.parametrize('bdim', [(5,), (5, 4), (2, 3, 5, 4)])
  479. @pytest.mark.parametrize('dtype', floating)
  480. def test_cho_solve_banded(self, lower, bdim, dtype):
  481. rng = np.random.default_rng(8342310302941288912051)
  482. A = get_random((2, 3, 3, 5), dtype=dtype, rng=rng)
  483. row_diag = 0 if lower else -1
  484. A[:, :, row_diag] = 10
  485. cb = linalg.cholesky_banded(A, lower=lower)
  486. b = get_random(bdim, dtype=dtype, rng=rng)
  487. x = linalg.cho_solve_banded((cb, lower), b)
  488. for i in range(2):
  489. for j in range(3):
  490. bij = b if len(bdim) <= 2 else b[i, j]
  491. xij = linalg.cho_solve_banded((cb[i, j], lower), bij)
  492. assert_allclose(x[i, j], xij)
  493. @pytest.mark.parametrize('bdim', [(5,), (5, 4), (2, 3, 5, 4)])
  494. @pytest.mark.parametrize('dtype', floating)
  495. def test_solveh_banded(self, bdim, dtype):
  496. rng = np.random.default_rng(8342310302941288912051)
  497. A = get_random((2, 3, 3, 5), dtype=dtype, rng=rng)
  498. A[:, :, -1] = 10
  499. b = get_random(bdim, dtype=dtype, rng=rng)
  500. x = linalg.solveh_banded(A, b)
  501. for i in range(2):
  502. for j in range(3):
  503. bij = b if len(bdim) <= 2 else b[i, j]
  504. xij = linalg.solveh_banded(A[i, j], bij)
  505. assert_allclose(x[i, j], xij)
  506. @pytest.mark.parametrize('bdim', [(5,), (5, 4), (2, 3, 5, 4)])
  507. @pytest.mark.parametrize('dtype', floating)
  508. def test_solve_triangular(self, bdim, dtype):
  509. rng = np.random.default_rng(8342310302941288912051)
  510. A = get_random((2, 3, 5, 5), dtype=dtype, rng=rng)
  511. A = np.tril(A)
  512. b = get_random(bdim, dtype=dtype, rng=rng)
  513. x = linalg.solve_triangular(A, b, lower=True)
  514. if len(bdim) == 1:
  515. x = x[..., np.newaxis]
  516. b = b[..., np.newaxis]
  517. atol = 1e-10 if dtype in (np.complex128, np.float64) else 2e-4
  518. assert_allclose(A @ x - b, 0, atol=atol)
  519. assert_allclose(x, np.linalg.solve(A, b), atol=5*atol)
  520. @pytest.mark.parametrize('bdim', [(4,), (4, 3), (2, 3, 4, 3)])
  521. @pytest.mark.parametrize('dtype', floating)
  522. def test_lstsq(self, bdim, dtype):
  523. rng = np.random.default_rng(8342310302941288912051)
  524. A = get_random((2, 3, 4, 5), dtype=dtype, rng=rng)
  525. b = get_random(bdim, dtype=dtype, rng=rng)
  526. res = linalg.lstsq(A, b)
  527. x = res[0]
  528. if len(bdim) == 1:
  529. x = x[..., np.newaxis]
  530. b = b[..., np.newaxis]
  531. assert_allclose(A @ x - b, 0, atol=2e-6)
  532. assert len(res) == 4
  533. @pytest.mark.parametrize('dtype', floating)
  534. def test_clarkson_woodruff_transform(self, dtype):
  535. rng = np.random.default_rng(8342310302941288912051)
  536. A = get_random((5, 3, 4, 6), dtype=dtype, rng=rng)
  537. self.batch_test(linalg.clarkson_woodruff_transform, A,
  538. kwargs=dict(sketch_size=3, rng=311224))
  539. def test_clarkson_woodruff_transform_sparse(self):
  540. rng = np.random.default_rng(8342310302941288912051)
  541. A = get_random((5, 3, 4, 6), dtype=np.float64, rng=rng)
  542. A = sparse.coo_array(A)
  543. message = "Batch support for sparse arrays is not available."
  544. with pytest.raises(NotImplementedError, match=message):
  545. linalg.clarkson_woodruff_transform(A, sketch_size=3, rng=rng)
  546. @pytest.mark.parametrize('f, args', [
  547. (linalg.toeplitz, (np.ones((0, 4)),)),
  548. (linalg.eig, (np.ones((3, 0, 5, 5)),)),
  549. ])
  550. def test_zero_size_batch(self, f, args):
  551. message = "does not support zero-size batches."
  552. with pytest.raises(ValueError, match=message):
  553. f(*args)