test_blas.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037
  1. #
  2. # Created by: Pearu Peterson, April 2002
  3. #
  4. import math
  5. import pytest
  6. import numpy as np
  7. from numpy.testing import (assert_equal, assert_almost_equal,
  8. assert_array_almost_equal, assert_allclose)
  9. from pytest import raises as assert_raises
  10. from numpy import (arange, triu, tril, zeros, tril_indices, ones,
  11. diag, append, eye, nonzero)
  12. import scipy
  13. from scipy.linalg import _fblas as fblas, get_blas_funcs, toeplitz, solve
  14. try:
  15. from scipy.linalg import _cblas as cblas
  16. except ImportError:
  17. cblas = None
  18. REAL_DTYPES = [np.float32, np.float64]
  19. COMPLEX_DTYPES = [np.complex64, np.complex128]
  20. DTYPES = REAL_DTYPES + COMPLEX_DTYPES
  21. def test_get_blas_funcs():
  22. # check that it returns Fortran code for arrays that are
  23. # fortran-ordered
  24. f1, f2, f3 = get_blas_funcs(
  25. ('axpy', 'axpy', 'axpy'),
  26. (np.empty((2, 2), dtype=np.complex64, order='F'),
  27. np.empty((2, 2), dtype=np.complex128, order='C'))
  28. )
  29. # get_blas_funcs will choose libraries depending on most generic
  30. # array
  31. assert_equal(f1.typecode, 'z')
  32. assert_equal(f2.typecode, 'z')
  33. if cblas is not None:
  34. assert_equal(f1.module_name, 'cblas')
  35. assert_equal(f2.module_name, 'cblas')
  36. # check defaults.
  37. f1 = get_blas_funcs('rotg')
  38. assert_equal(f1.typecode, 'd')
  39. # check also dtype interface
  40. f1 = get_blas_funcs('gemm', dtype=np.complex64)
  41. assert_equal(f1.typecode, 'c')
  42. f1 = get_blas_funcs('gemm', dtype='F')
  43. assert_equal(f1.typecode, 'c')
  44. # extended precision complex
  45. f1 = get_blas_funcs('gemm', dtype=np.clongdouble)
  46. assert_equal(f1.typecode, 'z')
  47. # check safe complex upcasting
  48. f1 = get_blas_funcs('axpy',
  49. (np.empty((2, 2), dtype=np.float64),
  50. np.empty((2, 2), dtype=np.complex64))
  51. )
  52. assert_equal(f1.typecode, 'z')
  53. def test_get_blas_funcs_alias():
  54. # check alias for get_blas_funcs
  55. f, g = get_blas_funcs(('nrm2', 'dot'), dtype=np.complex64)
  56. assert f.typecode == 'c'
  57. assert g.typecode == 'c'
  58. f, g, h = get_blas_funcs(('dot', 'dotc', 'dotu'), dtype=np.float64)
  59. assert f is g
  60. assert f is h
  61. def parametrize_blas(mod, func_name, prefixes):
  62. if mod is None:
  63. return pytest.mark.skip(reason="cblas not available")
  64. params = []
  65. for prefix in prefixes:
  66. if 'z' in prefix:
  67. dtype = np.complex128
  68. elif 'c' in prefix:
  69. dtype = np.complex64
  70. elif 'd' in prefix:
  71. dtype = np.float64
  72. else:
  73. assert 's' in prefix
  74. dtype = np.float32
  75. f = getattr(mod, prefix + func_name)
  76. params.append(pytest.param(f, dtype, id=prefix + func_name))
  77. return pytest.mark.parametrize("f,dtype", params)
  78. class TestCBLAS1Simple:
  79. @parametrize_blas(cblas, "axpy", "sdcz")
  80. def test_axpy(self, f, dtype):
  81. assert_array_almost_equal(f([1, 2, 3], [2, -1, 3], a=5),
  82. [7, 9, 18])
  83. if dtype in COMPLEX_DTYPES:
  84. assert_array_almost_equal(f([1, 2j, 3], [2, -1, 3], a=5),
  85. [7, 10j-1, 18])
  86. class TestFBLAS1Simple:
  87. @parametrize_blas(fblas, "axpy", "sdcz")
  88. def test_axpy(self, f, dtype):
  89. assert_array_almost_equal(f([1, 2, 3], [2, -1, 3], a=5),
  90. [7, 9, 18])
  91. if dtype in COMPLEX_DTYPES:
  92. assert_array_almost_equal(f([1, 2j, 3], [2, -1, 3], a=5),
  93. [7, 10j-1, 18])
  94. @parametrize_blas(fblas, "copy", "sdcz")
  95. def test_copy(self, f, dtype):
  96. assert_array_almost_equal(f([3, 4, 5], [8]*3), [3, 4, 5])
  97. if dtype in COMPLEX_DTYPES:
  98. assert_array_almost_equal(f([3, 4j, 5+3j], [8]*3), [3, 4j, 5+3j])
  99. @parametrize_blas(fblas, "asum", ["s", "d", "sc", "dz"])
  100. def test_asum(self, f, dtype):
  101. assert_almost_equal(f([3, -4, 5]), 12)
  102. if dtype in COMPLEX_DTYPES:
  103. assert_almost_equal(f([3j, -4, 3-4j]), 14)
  104. @parametrize_blas(fblas, "dot", "sd")
  105. def test_dot(self, f, dtype):
  106. assert_almost_equal(f([3, -4, 5], [2, 5, 1]), -9)
  107. @parametrize_blas(fblas, "dotu", "cz")
  108. def test_dotu(self, f, dtype):
  109. assert_almost_equal(f([3j, -4, 3-4j], [2, 3, 1]), -9+2j)
  110. @parametrize_blas(fblas, "dotc", "cz")
  111. def test_dotc(self, f, dtype):
  112. assert_almost_equal(f([3j, -4, 3-4j], [2, 3j, 1]), 3-14j)
  113. @parametrize_blas(fblas, "nrm2", ["s", "d", "sc", "dz"])
  114. def test_nrm2(self, f, dtype):
  115. assert_almost_equal(f([3, -4, 5]), math.sqrt(50))
  116. if dtype in COMPLEX_DTYPES:
  117. assert_almost_equal(f([3j, -4, 3-4j]), math.sqrt(50))
  118. @parametrize_blas(fblas, "scal", ["s", "d", "cs", "zd"])
  119. def test_scal(self, f, dtype):
  120. assert_array_almost_equal(f(2, [3, -4, 5]), [6, -8, 10])
  121. if dtype in COMPLEX_DTYPES:
  122. assert_array_almost_equal(f(3, [3j, -4, 3-4j]), [9j, -12, 9-12j])
  123. @parametrize_blas(fblas, "swap", "sdcz")
  124. def test_swap(self, f, dtype):
  125. x, y = [2, 3, 1], [-2, 3, 7]
  126. x1, y1 = f(x, y)
  127. assert_array_almost_equal(x1, y)
  128. assert_array_almost_equal(y1, x)
  129. if dtype in COMPLEX_DTYPES:
  130. x, y = [2, 3j, 1], [-2, 3, 7-3j]
  131. x1, y1 = f(x, y)
  132. assert_array_almost_equal(x1, y)
  133. assert_array_almost_equal(y1, x)
  134. @parametrize_blas(fblas, "amax", ["is", "id", "ic", "iz"])
  135. def test_amax(self, f, dtype):
  136. assert_equal(f([-2, 4, 3]), 1)
  137. if dtype in COMPLEX_DTYPES:
  138. assert_equal(f([-5, 4+3j, 6]), 1)
  139. # XXX: need tests for rot,rotm,rotg,rotmg
  140. class TestFBLAS2Simple:
  141. @parametrize_blas(fblas, "gemv", "sdcz")
  142. def test_gemv(self, f, dtype):
  143. assert_array_almost_equal(f(3, [[3]], [-4]), [-36])
  144. assert_array_almost_equal(f(3, [[3]], [-4], 3, [5]), [-21])
  145. if dtype in COMPLEX_DTYPES:
  146. assert_array_almost_equal(f(3j, [[3-4j]], [-4]), [-48-36j])
  147. assert_array_almost_equal(f(3j, [[3-4j]], [-4], 3, [5j]),
  148. [-48-21j])
  149. @parametrize_blas(fblas, "ger", "sd")
  150. def test_ger(self, f, dtype):
  151. assert_array_almost_equal(f(1, [1, 2], [3, 4]), [[3, 4], [6, 8]])
  152. assert_array_almost_equal(f(2, [1, 2, 3], [3, 4]),
  153. [[6, 8], [12, 16], [18, 24]])
  154. assert_array_almost_equal(f(1, [1, 2], [3, 4],
  155. a=[[1, 2], [3, 4]]), [[4, 6], [9, 12]])
  156. if dtype in COMPLEX_DTYPES:
  157. assert_array_almost_equal(f(1, [1j, 2], [3, 4]),
  158. [[3j, 4j], [6, 8]])
  159. assert_array_almost_equal(f(2, [1j, 2j, 3j], [3j, 4j]),
  160. [[6, 8], [12, 16], [18, 24]])
  161. @parametrize_blas(fblas, "geru", "cz")
  162. def test_geru(self, f, dtype):
  163. assert_array_almost_equal(f(1, [1j, 2], [3, 4]),
  164. [[3j, 4j], [6, 8]])
  165. assert_array_almost_equal(f(-2, [1j, 2j, 3j], [3j, 4j]),
  166. [[6, 8], [12, 16], [18, 24]])
  167. @parametrize_blas(fblas, "gerc", "cz")
  168. def test_gerc(self, f, dtype):
  169. assert_array_almost_equal(f(1, [1j, 2], [3, 4]),
  170. [[3j, 4j], [6, 8]])
  171. assert_array_almost_equal(f(2, [1j, 2j, 3j], [3j, 4j]),
  172. [[6, 8], [12, 16], [18, 24]])
  173. @parametrize_blas(fblas, "syr", "sdcz")
  174. def test_syr(self, f, dtype):
  175. x = np.arange(1, 5, dtype='d')
  176. resx = np.triu(x[:, np.newaxis] * x)
  177. resx_reverse = np.triu(x[::-1, np.newaxis] * x[::-1])
  178. y = np.linspace(0, 8.5, 17, endpoint=False)
  179. z = np.arange(1, 9, dtype='d').view('D')
  180. resz = np.triu(z[:, np.newaxis] * z)
  181. resz_reverse = np.triu(z[::-1, np.newaxis] * z[::-1])
  182. w = np.c_[np.zeros(4), z, np.zeros(4)].ravel()
  183. rtol = np.finfo(dtype).eps
  184. assert_allclose(f(1.0, x), resx, rtol=rtol)
  185. assert_allclose(f(1.0, x, lower=True), resx.T, rtol=rtol)
  186. assert_allclose(f(1.0, y, incx=2, offx=2, n=4), resx, rtol=rtol)
  187. # negative increments imply reversed vectors in blas
  188. assert_allclose(f(1.0, y, incx=-2, offx=2, n=4),
  189. resx_reverse, rtol=rtol)
  190. if dtype in COMPLEX_DTYPES:
  191. assert_allclose(f(1.0, z), resz, rtol=rtol)
  192. assert_allclose(f(1.0, z, lower=True), resz.T, rtol=rtol)
  193. assert_allclose(f(1.0, w, incx=3, offx=1, n=4), resz, rtol=rtol)
  194. # negative increments imply reversed vectors in blas
  195. assert_allclose(f(1.0, w, incx=-3, offx=1, n=4),
  196. resz_reverse, rtol=rtol)
  197. a = np.zeros((4, 4), dtype, 'F')
  198. b = f(1.0, z, a=a, overwrite_a=True)
  199. assert_allclose(a, resz, rtol=rtol)
  200. b = f(2.0, z, a=a)
  201. assert a is not b
  202. assert_allclose(b, 3*resz, rtol=rtol)
  203. else:
  204. a = np.zeros((4, 4), dtype, 'F')
  205. b = f(1.0, x, a=a, overwrite_a=True)
  206. assert_allclose(a, resx, rtol=rtol)
  207. b = f(2.0, x, a=a)
  208. assert a is not b
  209. assert_allclose(b, 3*resx, rtol=rtol)
  210. assert_raises(Exception, f, 1.0, x, incx=0)
  211. assert_raises(Exception, f, 1.0, x, offx=5)
  212. assert_raises(Exception, f, 1.0, x, offx=-2)
  213. assert_raises(Exception, f, 1.0, x, n=-2)
  214. assert_raises(Exception, f, 1.0, x, n=5)
  215. assert_raises(Exception, f, 1.0, x, lower=2)
  216. assert_raises(Exception, f, 1.0, x, a=np.zeros((2, 2), 'd', 'F'))
  217. @parametrize_blas(fblas, "her", "cz")
  218. def test_her(self, f, dtype):
  219. x = np.arange(1, 5, dtype='d')
  220. z = np.arange(1, 9, dtype='d').view('D')
  221. rehz = np.triu(z[:, np.newaxis] * z.conj())
  222. rehz_reverse = np.triu(z[::-1, np.newaxis] * z[::-1].conj())
  223. w = np.c_[np.zeros(4), z, np.zeros(4)].ravel()
  224. rtol = np.finfo(dtype).eps
  225. assert_allclose(f(1.0, z), rehz, rtol=rtol)
  226. assert_allclose(f(1.0, z, lower=True), rehz.T.conj(), rtol=rtol)
  227. assert_allclose(f(1.0, w, incx=3, offx=1, n=4), rehz, rtol=rtol)
  228. # negative increments imply reversed vectors in blas
  229. assert_allclose(f(1.0, w, incx=-3, offx=1, n=4),
  230. rehz_reverse, rtol=rtol)
  231. a = np.zeros((4, 4), dtype, 'F')
  232. b = f(1.0, z, a=a, overwrite_a=True)
  233. assert_allclose(a, rehz, rtol=rtol)
  234. b = f(2.0, z, a=a)
  235. assert a is not b
  236. assert_allclose(b, 3*rehz, rtol=rtol)
  237. assert_raises(Exception, f, 1.0, x, incx=0)
  238. assert_raises(Exception, f, 1.0, x, offx=5)
  239. assert_raises(Exception, f, 1.0, x, offx=-2)
  240. assert_raises(Exception, f, 1.0, x, n=-2)
  241. assert_raises(Exception, f, 1.0, x, n=5)
  242. assert_raises(Exception, f, 1.0, x, lower=2)
  243. assert_raises(Exception, f, 1.0, x, a=np.zeros((2, 2), 'd', 'F'))
  244. @parametrize_blas(fblas, "syr2", "sd")
  245. def test_syr2(self, f, dtype):
  246. x = np.arange(1, 5, dtype='d')
  247. y = np.arange(5, 9, dtype='d')
  248. resxy = np.triu(x[:, np.newaxis] * y + y[:, np.newaxis] * x)
  249. resxy_reverse = np.triu(x[::-1, np.newaxis] * y[::-1]
  250. + y[::-1, np.newaxis] * x[::-1])
  251. q = np.linspace(0, 8.5, 17, endpoint=False)
  252. rtol = np.finfo(dtype).eps
  253. assert_allclose(f(1.0, x, y), resxy, rtol=rtol)
  254. assert_allclose(f(1.0, x, y, n=3), resxy[:3, :3], rtol=rtol)
  255. assert_allclose(f(1.0, x, y, lower=True), resxy.T, rtol=rtol)
  256. assert_allclose(f(1.0, q, q, incx=2, offx=2, incy=2, offy=10),
  257. resxy, rtol=rtol)
  258. assert_allclose(f(1.0, q, q, incx=2, offx=2, incy=2, offy=10, n=3),
  259. resxy[:3, :3], rtol=rtol)
  260. # negative increments imply reversed vectors in blas
  261. assert_allclose(f(1.0, q, q, incx=-2, offx=2, incy=-2, offy=10),
  262. resxy_reverse, rtol=rtol)
  263. a = np.zeros((4, 4), dtype, 'F')
  264. b = f(1.0, x, y, a=a, overwrite_a=True)
  265. assert_allclose(a, resxy, rtol=rtol)
  266. b = f(2.0, x, y, a=a)
  267. assert a is not b
  268. assert_allclose(b, 3*resxy, rtol=rtol)
  269. assert_raises(Exception, f, 1.0, x, y, incx=0)
  270. assert_raises(Exception, f, 1.0, x, y, offx=5)
  271. assert_raises(Exception, f, 1.0, x, y, offx=-2)
  272. assert_raises(Exception, f, 1.0, x, y, incy=0)
  273. assert_raises(Exception, f, 1.0, x, y, offy=5)
  274. assert_raises(Exception, f, 1.0, x, y, offy=-2)
  275. assert_raises(Exception, f, 1.0, x, y, n=-2)
  276. assert_raises(Exception, f, 1.0, x, y, n=5)
  277. assert_raises(Exception, f, 1.0, x, y, lower=2)
  278. assert_raises(Exception, f, 1.0, x, y, a=np.zeros((2, 2), 'd', 'F'))
  279. @parametrize_blas(fblas, "her2", "cz")
  280. def test_her2(self, f, dtype):
  281. x = np.arange(1, 9, dtype='d').view('D')
  282. y = np.arange(9, 17, dtype='d').view('D')
  283. resxy = x[:, np.newaxis] * y.conj() + y[:, np.newaxis] * x.conj()
  284. resxy = np.triu(resxy)
  285. resxy_reverse = x[::-1, np.newaxis] * y[::-1].conj()
  286. resxy_reverse += y[::-1, np.newaxis] * x[::-1].conj()
  287. resxy_reverse = np.triu(resxy_reverse)
  288. u = np.c_[np.zeros(4), x, np.zeros(4)].ravel()
  289. v = np.c_[np.zeros(4), y, np.zeros(4)].ravel()
  290. rtol = np.finfo(dtype).eps
  291. assert_allclose(f(1.0, x, y), resxy, rtol=rtol)
  292. assert_allclose(f(1.0, x, y, n=3), resxy[:3, :3], rtol=rtol)
  293. assert_allclose(f(1.0, x, y, lower=True), resxy.T.conj(),
  294. rtol=rtol)
  295. assert_allclose(f(1.0, u, v, incx=3, offx=1, incy=3, offy=1),
  296. resxy, rtol=rtol)
  297. assert_allclose(f(1.0, u, v, incx=3, offx=1, incy=3, offy=1, n=3),
  298. resxy[:3, :3], rtol=rtol)
  299. # negative increments imply reversed vectors in blas
  300. assert_allclose(f(1.0, u, v, incx=-3, offx=1, incy=-3, offy=1),
  301. resxy_reverse, rtol=rtol)
  302. a = np.zeros((4, 4), dtype, 'F')
  303. b = f(1.0, x, y, a=a, overwrite_a=True)
  304. assert_allclose(a, resxy, rtol=rtol)
  305. b = f(2.0, x, y, a=a)
  306. assert a is not b
  307. assert_allclose(b, 3*resxy, rtol=rtol)
  308. assert_raises(Exception, f, 1.0, x, y, incx=0)
  309. assert_raises(Exception, f, 1.0, x, y, offx=5)
  310. assert_raises(Exception, f, 1.0, x, y, offx=-2)
  311. assert_raises(Exception, f, 1.0, x, y, incy=0)
  312. assert_raises(Exception, f, 1.0, x, y, offy=5)
  313. assert_raises(Exception, f, 1.0, x, y, offy=-2)
  314. assert_raises(Exception, f, 1.0, x, y, n=-2)
  315. assert_raises(Exception, f, 1.0, x, y, n=5)
  316. assert_raises(Exception, f, 1.0, x, y, lower=2)
  317. assert_raises(Exception, f, 1.0, x, y,
  318. a=np.zeros((2, 2), 'd', 'F'))
  319. @pytest.mark.parametrize("dtype", DTYPES)
  320. def test_gbmv(self, dtype):
  321. rng = np.random.default_rng(1234)
  322. n = 7
  323. m = 5
  324. kl = 1
  325. ku = 2
  326. # fake a banded matrix via toeplitz
  327. A = toeplitz(append(rng.random(kl+1), zeros(m-kl-1)),
  328. append(rng.random(ku+1), zeros(n-ku-1)))
  329. A = A.astype(dtype)
  330. Ab = zeros((kl+ku+1, n), dtype=dtype)
  331. # Form the banded storage
  332. Ab[2, :5] = A[0, 0] # diag
  333. Ab[1, 1:6] = A[0, 1] # sup1
  334. Ab[0, 2:7] = A[0, 2] # sup2
  335. Ab[3, :4] = A[1, 0] # sub1
  336. x = rng.random(n).astype(dtype)
  337. y = rng.random(m).astype(dtype)
  338. alpha, beta = dtype(3), dtype(-5)
  339. func, = get_blas_funcs(('gbmv',), dtype=dtype)
  340. y1 = func(m=m, n=n, ku=ku, kl=kl, alpha=alpha, a=Ab,
  341. x=x, y=y, beta=beta)
  342. y2 = alpha * A.dot(x) + beta * y
  343. assert_array_almost_equal(y1, y2)
  344. y1 = func(m=m, n=n, ku=ku, kl=kl, alpha=alpha, a=Ab,
  345. x=y, y=x, beta=beta, trans=1)
  346. y2 = alpha * A.T.dot(y) + beta * x
  347. assert_array_almost_equal(y1, y2)
  348. @pytest.mark.parametrize("dtype", DTYPES)
  349. def test_sbmv_hbmv(self, dtype):
  350. rng = np.random.default_rng(1234)
  351. n = 6
  352. k = 2
  353. A = zeros((n, n), dtype=dtype)
  354. Ab = zeros((k+1, n), dtype=dtype)
  355. # Form the array and its packed banded storage
  356. A[arange(n), arange(n)] = rng.random(n)
  357. for ind2 in range(1, k+1):
  358. temp = rng.random(n-ind2)
  359. A[arange(n-ind2), arange(ind2, n)] = temp
  360. Ab[-1-ind2, ind2:] = temp
  361. A = A.astype(dtype)
  362. if dtype in COMPLEX_DTYPES:
  363. A += A.conj().T
  364. func, = get_blas_funcs(('hbmv',), dtype=dtype)
  365. else:
  366. A += A.T
  367. func, = get_blas_funcs(('sbmv',), dtype=dtype)
  368. Ab[-1, :] = diag(A)
  369. x = rng.random(n).astype(dtype)
  370. y = rng.random(n).astype(dtype)
  371. alpha, beta = dtype(1.25), dtype(3)
  372. y1 = func(k=k, alpha=alpha, a=Ab, x=x, y=y, beta=beta)
  373. y2 = alpha * A.dot(x) + beta * y
  374. assert_array_almost_equal(y1, y2)
  375. @pytest.mark.parametrize("fname,dtype", [
  376. *[('spmv', dtype) for dtype in REAL_DTYPES + COMPLEX_DTYPES],
  377. *[('hpmv', dtype) for dtype in COMPLEX_DTYPES],
  378. ])
  379. def test_spmv_hpmv(self, fname, dtype):
  380. rng = np.random.default_rng(1234)
  381. n = 3
  382. A = rng.random((n, n)).astype(dtype)
  383. if dtype in COMPLEX_DTYPES:
  384. A += rng.random((n, n))*1j
  385. A += A.T if fname == 'spmv' else A.conj().T
  386. c, r = tril_indices(n)
  387. Ap = A[r, c]
  388. x = rng.random(n).astype(dtype)
  389. y = rng.random(n).astype(dtype)
  390. xlong = arange(2*n).astype(dtype)
  391. ylong = ones(2*n).astype(dtype)
  392. alpha, beta = dtype(1.25), dtype(2)
  393. func, = get_blas_funcs((fname,), dtype=dtype)
  394. y1 = func(n=n, alpha=alpha, ap=Ap, x=x, y=y, beta=beta)
  395. y2 = alpha * A.dot(x) + beta * y
  396. assert_array_almost_equal(y1, y2)
  397. # Test inc and offsets
  398. y1 = func(n=n-1, alpha=alpha, beta=beta, x=xlong, y=ylong, ap=Ap,
  399. incx=2, incy=2, offx=n, offy=n)
  400. y2 = (alpha * A[:-1, :-1]).dot(xlong[3::2]) + beta * ylong[3::2]
  401. assert_array_almost_equal(y1[3::2], y2)
  402. assert_almost_equal(y1[4], ylong[4])
  403. @pytest.mark.parametrize("fname,dtype", [
  404. *[('spr', dtype) for dtype in REAL_DTYPES + COMPLEX_DTYPES],
  405. *[('hpr', dtype) for dtype in COMPLEX_DTYPES],
  406. ])
  407. def test_spr_hpr(self, fname, dtype):
  408. rng = np.random.default_rng(1234)
  409. n = 3
  410. A = rng.random((n, n)).astype(dtype)
  411. if dtype in COMPLEX_DTYPES:
  412. A += rng.random((n, n))*1j
  413. A += A.T if fname == 'spr' else A.conj().T
  414. c, r = tril_indices(n)
  415. Ap = A[r, c]
  416. x = rng.random(n).astype(dtype)
  417. alpha = np.finfo(dtype).dtype.type(2.5)
  418. if fname == 'hpr':
  419. func, = get_blas_funcs(('hpr',), dtype=dtype)
  420. y2 = alpha * x[:, None].dot(x[None, :].conj()) + A
  421. else:
  422. func, = get_blas_funcs(('spr',), dtype=dtype)
  423. y2 = alpha * x[:, None].dot(x[None, :]) + A
  424. y1 = func(n=n, alpha=alpha, ap=Ap, x=x)
  425. y1f = zeros((3, 3), dtype=dtype)
  426. y1f[r, c] = y1
  427. y1f[c, r] = y1.conj() if fname == 'hpr' else y1
  428. assert_array_almost_equal(y1f, y2)
  429. @pytest.mark.parametrize("dtype", DTYPES)
  430. def test_spr2_hpr2(self, dtype):
  431. rng = np.random.default_rng(1234)
  432. n = 3
  433. A = rng.random((n, n)).astype(dtype)
  434. if dtype in COMPLEX_DTYPES:
  435. A += rng.random((n, n))*1j
  436. A += A.conj().T
  437. func, = get_blas_funcs(('hpr2',), dtype=dtype)
  438. else:
  439. A += A.T
  440. func, = get_blas_funcs(('spr2',), dtype=dtype)
  441. c, r = tril_indices(n)
  442. Ap = A[r, c]
  443. x = rng.random(n).astype(dtype)
  444. y = rng.random(n).astype(dtype)
  445. alpha = dtype(2)
  446. u = alpha.conj() * x[:, None].dot(y[None, :].conj())
  447. y2 = A + u + u.conj().T
  448. y1 = func(n=n, alpha=alpha, x=x, y=y, ap=Ap)
  449. y1f = zeros((3, 3), dtype=dtype)
  450. y1f[r, c] = y1
  451. y1f[[1, 2, 2], [0, 0, 1]] = y1[[1, 3, 4]].conj()
  452. assert_array_almost_equal(y1f, y2)
  453. @pytest.mark.parametrize("dtype", DTYPES)
  454. def test_tbmv(self, dtype):
  455. rng = np.random.default_rng(1234)
  456. n = 10
  457. k = 3
  458. x = rng.random(n).astype(dtype)
  459. A = zeros((n, n), dtype=dtype)
  460. # Banded upper triangular array
  461. for sup in range(k+1):
  462. A[arange(n-sup), arange(sup, n)] = rng.random(n-sup)
  463. # Add complex parts for c,z
  464. if dtype in COMPLEX_DTYPES:
  465. A[nonzero(A)] += 1j * rng.random((k+1)*n-(k*(k+1)//2)).astype(dtype)
  466. # Form the banded storage
  467. Ab = zeros((k+1, n), dtype=dtype)
  468. for row in range(k+1):
  469. Ab[-row-1, row:] = diag(A, k=row)
  470. func, = get_blas_funcs(('tbmv',), dtype=dtype)
  471. y1 = func(k=k, a=Ab, x=x)
  472. y2 = A.dot(x)
  473. assert_array_almost_equal(y1, y2)
  474. y1 = func(k=k, a=Ab, x=x, diag=1)
  475. A[arange(n), arange(n)] = dtype(1)
  476. y2 = A.dot(x)
  477. assert_array_almost_equal(y1, y2)
  478. y1 = func(k=k, a=Ab, x=x, diag=1, trans=1)
  479. y2 = A.T.dot(x)
  480. assert_array_almost_equal(y1, y2)
  481. y1 = func(k=k, a=Ab, x=x, diag=1, trans=2)
  482. y2 = A.conj().T.dot(x)
  483. assert_array_almost_equal(y1, y2)
  484. @pytest.mark.parametrize("dtype", DTYPES)
  485. def test_tbsv(self, dtype):
  486. rng = np.random.default_rng(12345)
  487. n = 6
  488. k = 3
  489. x = rng.random(n).astype(dtype)
  490. A = zeros((n, n), dtype=dtype)
  491. # Banded upper triangular array
  492. for sup in range(k+1):
  493. A[arange(n-sup), arange(sup, n)] = rng.random(n-sup)
  494. # Add complex parts for c,z
  495. if dtype in COMPLEX_DTYPES:
  496. A[nonzero(A)] += 1j * rng.random((k+1)*n-(k*(k+1)//2)).astype(dtype)
  497. # Form the banded storage
  498. Ab = zeros((k+1, n), dtype=dtype)
  499. for row in range(k+1):
  500. Ab[-row-1, row:] = diag(A, k=row)
  501. func, = get_blas_funcs(('tbsv',), dtype=dtype)
  502. y1 = func(k=k, a=Ab, x=x)
  503. y2 = solve(A, x)
  504. assert_array_almost_equal(y1, y2)
  505. y1 = func(k=k, a=Ab, x=x, diag=1)
  506. A[arange(n), arange(n)] = dtype(1)
  507. y2 = solve(A, x)
  508. assert_array_almost_equal(y1, y2)
  509. y1 = func(k=k, a=Ab, x=x, diag=1, trans=1)
  510. y2 = solve(A.T, x)
  511. assert_array_almost_equal(y1, y2)
  512. y1 = func(k=k, a=Ab, x=x, diag=1, trans=2)
  513. y2 = solve(A.conj().T, x)
  514. assert_array_almost_equal(y1, y2)
  515. @pytest.mark.parametrize("dtype", DTYPES)
  516. def test_tpmv(self, dtype):
  517. rng = np.random.default_rng(1234)
  518. n = 10
  519. x = rng.random(n).astype(dtype)
  520. # Upper triangular array
  521. if dtype in COMPLEX_DTYPES:
  522. A = triu(rng.random((n, n)) + rng.random((n, n))*1j)
  523. else:
  524. A = triu(rng.random((n, n)))
  525. # Form the packed storage
  526. c, r = tril_indices(n)
  527. Ap = A[r, c]
  528. func, = get_blas_funcs(('tpmv',), dtype=dtype)
  529. y1 = func(n=n, ap=Ap, x=x)
  530. y2 = A.dot(x)
  531. assert_array_almost_equal(y1, y2)
  532. y1 = func(n=n, ap=Ap, x=x, diag=1)
  533. A[arange(n), arange(n)] = dtype(1)
  534. y2 = A.dot(x)
  535. assert_array_almost_equal(y1, y2)
  536. y1 = func(n=n, ap=Ap, x=x, diag=1, trans=1)
  537. y2 = A.T.dot(x)
  538. assert_array_almost_equal(y1, y2)
  539. y1 = func(n=n, ap=Ap, x=x, diag=1, trans=2)
  540. y2 = A.conj().T.dot(x)
  541. assert_array_almost_equal(y1, y2)
  542. @pytest.mark.parametrize("dtype", DTYPES)
  543. def test_tpsv(self, dtype):
  544. rng = np.random.default_rng(1234)
  545. n = 10
  546. x = rng.random(n).astype(dtype)
  547. # Upper triangular array
  548. if dtype in COMPLEX_DTYPES:
  549. A = triu(rng.random((n, n)) + rng.random((n, n))*1j)
  550. else:
  551. A = triu(rng.random((n, n)))
  552. A += eye(n)
  553. # Form the packed storage
  554. c, r = tril_indices(n)
  555. Ap = A[r, c]
  556. func, = get_blas_funcs(('tpsv',), dtype=dtype)
  557. y1 = func(n=n, ap=Ap, x=x)
  558. y2 = solve(A, x)
  559. assert_array_almost_equal(y1, y2)
  560. y1 = func(n=n, ap=Ap, x=x, diag=1)
  561. A[arange(n), arange(n)] = dtype(1)
  562. y2 = solve(A, x)
  563. assert_array_almost_equal(y1, y2)
  564. y1 = func(n=n, ap=Ap, x=x, diag=1, trans=1)
  565. y2 = solve(A.T, x)
  566. assert_array_almost_equal(y1, y2)
  567. y1 = func(n=n, ap=Ap, x=x, diag=1, trans=2)
  568. y2 = solve(A.conj().T, x)
  569. assert_array_almost_equal(y1, y2)
  570. @pytest.mark.parametrize("dtype", DTYPES)
  571. def test_trmv(self, dtype):
  572. rng = np.random.default_rng(1234)
  573. n = 3
  574. A = (rng.random((n, n))+eye(n)).astype(dtype)
  575. x = rng.random(3).astype(dtype)
  576. func, = get_blas_funcs(('trmv',), dtype=dtype)
  577. y1 = func(a=A, x=x)
  578. y2 = triu(A).dot(x)
  579. assert_array_almost_equal(y1, y2)
  580. y1 = func(a=A, x=x, diag=1)
  581. A[arange(n), arange(n)] = dtype(1)
  582. y2 = triu(A).dot(x)
  583. assert_array_almost_equal(y1, y2)
  584. y1 = func(a=A, x=x, diag=1, trans=1)
  585. y2 = triu(A).T.dot(x)
  586. assert_array_almost_equal(y1, y2)
  587. y1 = func(a=A, x=x, diag=1, trans=2)
  588. y2 = triu(A).conj().T.dot(x)
  589. assert_array_almost_equal(y1, y2)
  590. @pytest.mark.parametrize("dtype", DTYPES)
  591. def test_trsv(self, dtype):
  592. rng = np.random.default_rng(1234)
  593. n = 15
  594. A = (rng.random((n, n))+eye(n)).astype(dtype)
  595. x = rng.random(n).astype(dtype)
  596. func, = get_blas_funcs(('trsv',), dtype=dtype)
  597. y1 = func(a=A, x=x)
  598. y2 = solve(triu(A), x)
  599. assert_array_almost_equal(y1, y2)
  600. y1 = func(a=A, x=x, lower=1)
  601. y2 = solve(tril(A), x)
  602. assert_array_almost_equal(y1, y2)
  603. y1 = func(a=A, x=x, diag=1)
  604. A[arange(n), arange(n)] = dtype(1)
  605. y2 = solve(triu(A), x)
  606. assert_array_almost_equal(y1, y2)
  607. y1 = func(a=A, x=x, diag=1, trans=1)
  608. y2 = solve(triu(A).T, x)
  609. assert_array_almost_equal(y1, y2)
  610. y1 = func(a=A, x=x, diag=1, trans=2)
  611. y2 = solve(triu(A).conj().T, x)
  612. assert_array_almost_equal(y1, y2)
  613. class TestFBLAS3Simple:
  614. @parametrize_blas(fblas, "gemm", "sdcz")
  615. def test_gemm(self, f, dtype):
  616. assert_array_almost_equal(f(3, [3], [-4]), [[-36]])
  617. assert_array_almost_equal(f(3, [3], [-4], 3, [5]), [-21])
  618. if dtype in COMPLEX_DTYPES:
  619. assert_array_almost_equal(f(3j, [3-4j], [-4]), [[-48-36j]])
  620. assert_array_almost_equal(f(3j, [3-4j], [-4], 3, [5j]), [-48-21j])
  621. class TestBLAS3Symm:
  622. def setup_method(self):
  623. self.a = np.array([[1., 2.],
  624. [0., 1.]])
  625. self.b = np.array([[1., 0., 3.],
  626. [0., -1., 2.]])
  627. self.c = np.ones((2, 3))
  628. self.t = np.array([[2., -1., 8.],
  629. [3., 0., 9.]])
  630. @parametrize_blas(fblas, "symm", "sdcz")
  631. def test_symm(self, f, dtype):
  632. res = f(a=self.a, b=self.b, c=self.c, alpha=1., beta=1.)
  633. assert_array_almost_equal(res, self.t)
  634. res = f(a=self.a.T, b=self.b, lower=1, c=self.c, alpha=1., beta=1.)
  635. assert_array_almost_equal(res, self.t)
  636. res = f(a=self.a, b=self.b.T, side=1, c=self.c.T,
  637. alpha=1., beta=1.)
  638. assert_array_almost_equal(res, self.t.T)
  639. @parametrize_blas(fblas, "symm", "sdcz")
  640. def test_symm_wrong_side(self, f, dtype):
  641. """`side=1` means C <- B*A, hence shapes of A and B are to be
  642. compatible. Otherwise, f2py exception is raised.
  643. """
  644. # FIXME narrow down to _fblas.error
  645. with pytest.raises(Exception):
  646. f(a=self.a, b=self.b, alpha=1, side=1)
  647. @parametrize_blas(fblas, "symm", "sdcz")
  648. def test_symm_wrong_uplo(self, f, dtype):
  649. """SYMM only considers the upper/lower part of A. Hence setting
  650. wrong value for `lower` (default is lower=0, meaning upper triangle)
  651. gives a wrong result.
  652. """
  653. res = f(a=self.a, b=self.b, c=self.c, alpha=1., beta=1.)
  654. assert np.allclose(res, self.t)
  655. res = f(a=self.a, b=self.b, lower=1, c=self.c, alpha=1., beta=1.)
  656. assert not np.allclose(res, self.t)
  657. class TestBLAS3Syrk:
  658. def setup_method(self):
  659. self.a = np.array([[1., 0.],
  660. [0., -2.],
  661. [2., 3.]])
  662. self.t = np.array([[1., 0., 2.],
  663. [0., 4., -6.],
  664. [2., -6., 13.]])
  665. self.tt = np.array([[5., 6.],
  666. [6., 13.]])
  667. @parametrize_blas(fblas, "syrk", "sdcz")
  668. def test_syrk(self, f, dtype):
  669. c = f(a=self.a, alpha=1.)
  670. assert_array_almost_equal(np.triu(c), np.triu(self.t))
  671. c = f(a=self.a, alpha=1., lower=1)
  672. assert_array_almost_equal(np.tril(c), np.tril(self.t))
  673. c0 = np.ones(self.t.shape)
  674. c = f(a=self.a, alpha=1., beta=1., c=c0)
  675. assert_array_almost_equal(np.triu(c), np.triu(self.t+c0))
  676. c = f(a=self.a, alpha=1., trans=1)
  677. assert_array_almost_equal(np.triu(c), np.triu(self.tt))
  678. # prints '0-th dimension must be fixed to 3 but got 5',
  679. # FIXME: suppress?
  680. @parametrize_blas(fblas, "syrk", "sdcz")
  681. def test_syrk_wrong_c(self, f, dtype):
  682. # FIXME narrow down to _fblas.error
  683. with pytest.raises(Exception):
  684. f(a=self.a, alpha=1., c=np.ones((5, 8)))
  685. # if C is supplied, it must have compatible dimensions
  686. class TestBLAS3Syr2k:
  687. def setup_method(self):
  688. self.a = np.array([[1., 0.],
  689. [0., -2.],
  690. [2., 3.]])
  691. self.b = np.array([[0., 1.],
  692. [1., 0.],
  693. [0, 1.]])
  694. self.t = np.array([[0., -1., 3.],
  695. [-1., 0., 0.],
  696. [3., 0., 6.]])
  697. self.tt = np.array([[0., 1.],
  698. [1., 6]])
  699. @parametrize_blas(fblas, "syr2k", "sdcz")
  700. def test_syr2k(self, f, dtype):
  701. c = f(a=self.a, b=self.b, alpha=1.)
  702. assert_array_almost_equal(np.triu(c), np.triu(self.t))
  703. c = f(a=self.a, b=self.b, alpha=1., lower=1)
  704. assert_array_almost_equal(np.tril(c), np.tril(self.t))
  705. c0 = np.ones(self.t.shape)
  706. c = f(a=self.a, b=self.b, alpha=1., beta=1., c=c0)
  707. assert_array_almost_equal(np.triu(c), np.triu(self.t+c0))
  708. c = f(a=self.a, b=self.b, alpha=1., trans=1)
  709. assert_array_almost_equal(np.triu(c), np.triu(self.tt))
  710. # prints '0-th dimension must be fixed to 3 but got 5', FIXME: suppress?
  711. @parametrize_blas(fblas, "syr2k", "sdcz")
  712. def test_syr2k_wrong_c(self, f, dtype):
  713. with pytest.raises(Exception):
  714. f(a=self.a, b=self.b, alpha=1., c=np.zeros((15, 8)))
  715. # if C is supplied, it must have compatible dimensions
  716. class TestSyHe:
  717. """Quick and simple tests for (zc)-symm, syrk, syr2k."""
  718. def setup_method(self):
  719. self.sigma_y = np.array([[0., -1.j],
  720. [1.j, 0.]])
  721. @parametrize_blas(fblas, "symm", "zc")
  722. def test_symm(self, f, dtype):
  723. # NB: a is symmetric w/upper diag of ONLY
  724. res = f(a=self.sigma_y, b=self.sigma_y, alpha=1.)
  725. assert_array_almost_equal(np.triu(res), np.diag([1, -1]))
  726. @parametrize_blas(fblas, "hemm", "zc")
  727. def test_hemm(self, f, dtype):
  728. # NB: a is hermitian w/upper diag of ONLY
  729. res = f(a=self.sigma_y, b=self.sigma_y, alpha=1.)
  730. assert_array_almost_equal(np.triu(res), np.diag([1, 1]))
  731. @parametrize_blas(fblas, "syrk", "zc")
  732. def test_syrk(self, f, dtype):
  733. res = f(a=self.sigma_y, alpha=1.)
  734. assert_array_almost_equal(np.triu(res), np.diag([-1, -1]))
  735. @parametrize_blas(fblas, "herk", "zc")
  736. def test_herk(self, f, dtype):
  737. res = f(a=self.sigma_y, alpha=1.)
  738. assert_array_almost_equal(np.triu(res), np.diag([1, 1]))
  739. @parametrize_blas(fblas, "syr2k", "zc")
  740. def test_syr2k_zr(self, f, dtype):
  741. res = f(a=self.sigma_y, b=self.sigma_y, alpha=1.)
  742. assert_array_almost_equal(np.triu(res), 2.*np.diag([-1, -1]))
  743. @parametrize_blas(fblas, "her2k", "zc")
  744. def test_her2k_zr(self, f, dtype):
  745. res = f(a=self.sigma_y, b=self.sigma_y, alpha=1.)
  746. assert_array_almost_equal(np.triu(res), 2.*np.diag([1, 1]))
  747. class TestTRMM:
  748. """Quick and simple tests for *trmm."""
  749. def setup_method(self):
  750. self.a = np.array([[1., 2., ],
  751. [-2., 1.]])
  752. self.b = np.array([[3., 4., -1.],
  753. [5., 6., -2.]])
  754. self.a2 = np.array([[1, 1, 2, 3],
  755. [0, 1, 4, 5],
  756. [0, 0, 1, 6],
  757. [0, 0, 0, 1]], order="f")
  758. self.b2 = np.array([[1, 4], [2, 5], [3, 6], [7, 8], [9, 10]],
  759. order="f")
  760. @pytest.mark.parametrize("dtype", DTYPES)
  761. def test_side(self, dtype):
  762. trmm = get_blas_funcs("trmm", dtype=dtype)
  763. # Provide large A array that works for side=1 but not 0 (see gh-10841)
  764. assert_raises(Exception, trmm, 1.0, self.a2, self.b2)
  765. res = trmm(1.0, self.a2.astype(dtype), self.b2.astype(dtype),
  766. side=1)
  767. k = self.b2.shape[1]
  768. assert_allclose(res, self.b2 @ self.a2[:k, :k], rtol=0.,
  769. atol=100*np.finfo(dtype).eps)
  770. @parametrize_blas(fblas, "trmm", "sdcz")
  771. def test_ab(self, f, dtype):
  772. result = f(1., self.a, self.b)
  773. # default a is upper triangular
  774. expected = np.array([[13., 16., -5.],
  775. [ 5., 6., -2.]])
  776. assert_array_almost_equal(result, expected)
  777. @parametrize_blas(fblas, "trmm", "sdcz")
  778. def test_ab_lower(self, f, dtype):
  779. result = f(1., self.a, self.b, lower=True)
  780. expected = np.array([[ 3., 4., -1.],
  781. [-1., -2., 0.]]) # now a is lower triangular
  782. assert_array_almost_equal(result, expected)
  783. @parametrize_blas(fblas, "trmm", "sdcz")
  784. def test_b_overwrites(self, f, dtype):
  785. # BLAS *trmm modifies B argument in-place.
  786. # Here the default is to copy, but this can be overridden
  787. b = self.b.astype(dtype)
  788. for overwr in [True, False]:
  789. bcopy = b.copy()
  790. result = f(1., self.a, bcopy, overwrite_b=overwr)
  791. # C-contiguous arrays are copied
  792. assert not bcopy.flags.f_contiguous
  793. assert not np.may_share_memory(bcopy, result)
  794. assert_equal(bcopy, b)
  795. bcopy = np.asfortranarray(b.copy()) # or just transpose it
  796. result = f(1., self.a, bcopy, overwrite_b=True)
  797. assert bcopy.flags.f_contiguous
  798. assert np.may_share_memory(bcopy, result)
  799. assert_array_almost_equal(bcopy, result)
  800. @pytest.mark.parametrize("dtype", DTYPES)
  801. def test_trsm(dtype):
  802. rng = np.random.default_rng(1234)
  803. tol = np.finfo(dtype).eps*1000
  804. func, = get_blas_funcs(('trsm',), dtype=dtype)
  805. # Test protection against size mismatches
  806. A = rng.random((4, 5)).astype(dtype)
  807. B = rng.random((4, 4)).astype(dtype)
  808. alpha = dtype(1)
  809. assert_raises(Exception, func, alpha, A, B)
  810. assert_raises(Exception, func, alpha, A.T, B)
  811. n = 8
  812. m = 7
  813. alpha = dtype(-2.5)
  814. if dtype in COMPLEX_DTYPES:
  815. A = (rng.random((m, m)) + rng.random((m, m))*1j) + eye(m)
  816. else:
  817. A = rng.random((m, m)) + eye(m)
  818. A = A.astype(dtype)
  819. Au = triu(A)
  820. Al = tril(A)
  821. B1 = rng.random((m, n)).astype(dtype)
  822. B2 = rng.random((n, m)).astype(dtype)
  823. x1 = func(alpha=alpha, a=A, b=B1)
  824. assert_equal(B1.shape, x1.shape)
  825. x2 = solve(Au, alpha*B1)
  826. assert_allclose(x1, x2, atol=tol)
  827. x1 = func(alpha=alpha, a=A, b=B1, trans_a=1)
  828. x2 = solve(Au.T, alpha*B1)
  829. assert_allclose(x1, x2, atol=tol)
  830. x1 = func(alpha=alpha, a=A, b=B1, trans_a=2)
  831. x2 = solve(Au.conj().T, alpha*B1)
  832. assert_allclose(x1, x2, atol=tol)
  833. x1 = func(alpha=alpha, a=A, b=B1, diag=1)
  834. Au[arange(m), arange(m)] = dtype(1)
  835. x2 = solve(Au, alpha*B1)
  836. assert_allclose(x1, x2, atol=tol)
  837. x1 = func(alpha=alpha, a=A, b=B2, diag=1, side=1)
  838. x2 = solve(Au.conj().T, alpha*B2.conj().T)
  839. assert_allclose(x1, x2.conj().T, atol=tol)
  840. x1 = func(alpha=alpha, a=A, b=B2, diag=1, side=1, lower=1)
  841. Al[arange(m), arange(m)] = dtype(1)
  842. x2 = solve(Al.conj().T, alpha*B2.conj().T)
  843. assert_allclose(x1, x2.conj().T, atol=tol)
  844. @pytest.mark.xfail(run=False,
  845. reason="gh-16930")
  846. def test_gh_169309():
  847. x = np.repeat(10, 9)
  848. actual = scipy.linalg.blas.dnrm2(x, 5, 3, -1)
  849. expected = math.sqrt(500)
  850. assert_allclose(actual, expected)
  851. def test_dnrm2_neg_incx():
  852. # check that dnrm2(..., incx < 0) raises
  853. # XXX: remove the test after the lowest supported BLAS implements
  854. # negative incx (new in LAPACK 3.10)
  855. x = np.repeat(10, 9)
  856. incx = -1
  857. with assert_raises(fblas.__fblas_error):
  858. scipy.linalg.blas.dnrm2(x, 5, 3, incx)