test_fblas.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602
# Test interfaces to Fortran BLAS.
#
# The tests exercise the interface more than the underlying BLAS.
# Only very small matrices are checked -- N=3 or so.
#
# !! Complex calculations really aren't checked that carefully.
# !! Apart from the gemv tests (whose data is scaled by 1+1j), only
# !! real-valued complex numbers are used in the tests.
  8. from itertools import product
  9. import sys
  10. import numpy as np
  11. from numpy import float32, float64, complex64, complex128, arange, array, \
  12. zeros, shape, transpose, newaxis, common_type, conjugate
  13. from scipy.linalg import _fblas as fblas
  14. from numpy.testing import assert_array_equal, \
  15. assert_allclose, assert_array_almost_equal, assert_
  16. import pytest
# Decimal accuracy to require between Python and LAPACK/BLAS calculations.
# NOTE(review): nothing visible in this file reads `accuracy`; the assertions
# below are called with their own default tolerances -- confirm whether this
# constant is still needed.
accuracy = 5
# numpy.dot most likely calls into the same BLAS, so reference results are
# computed with this plain Python-loop routine instead.
  21. def matrixmultiply(a, b):
  22. if len(b.shape) == 1:
  23. b_is_vector = True
  24. b = b[:, newaxis]
  25. else:
  26. b_is_vector = False
  27. assert_(a.shape[1] == b.shape[0])
  28. c = zeros((a.shape[0], b.shape[1]), common_type(a, b))
  29. for i in range(a.shape[0]):
  30. for j in range(b.shape[1]):
  31. s = 0
  32. for k in range(a.shape[1]):
  33. s += a[i, k] * b[k, j]
  34. c[i, j] = s
  35. if b_is_vector:
  36. c = c.reshape((a.shape[0],))
  37. return c
  38. ##################################################
  39. # Test blas ?axpy
  40. class BaseAxpy:
  41. ''' Mixin class for axpy tests '''
  42. def test_default_a(self):
  43. x = arange(3., dtype=self.dtype)
  44. y = arange(3., dtype=x.dtype)
  45. real_y = x*1.+y
  46. y = self.blas_func(x, y)
  47. assert_array_equal(real_y, y)
  48. def test_simple(self):
  49. x = arange(3., dtype=self.dtype)
  50. y = arange(3., dtype=x.dtype)
  51. real_y = x*3.+y
  52. y = self.blas_func(x, y, a=3.)
  53. assert_array_equal(real_y, y)
  54. def test_x_stride(self):
  55. x = arange(6., dtype=self.dtype)
  56. y = zeros(3, x.dtype)
  57. y = arange(3., dtype=x.dtype)
  58. real_y = x[::2]*3.+y
  59. y = self.blas_func(x, y, a=3., n=3, incx=2)
  60. assert_array_equal(real_y, y)
  61. def test_y_stride(self):
  62. x = arange(3., dtype=self.dtype)
  63. y = zeros(6, x.dtype)
  64. real_y = x*3.+y[::2]
  65. y = self.blas_func(x, y, a=3., n=3, incy=2)
  66. assert_array_equal(real_y, y[::2])
  67. def test_x_and_y_stride(self):
  68. x = arange(12., dtype=self.dtype)
  69. y = zeros(6, x.dtype)
  70. real_y = x[::4]*3.+y[::2]
  71. y = self.blas_func(x, y, a=3., n=3, incx=4, incy=2)
  72. assert_array_equal(real_y, y[::2])
  73. def test_x_bad_size(self):
  74. x = arange(12., dtype=self.dtype)
  75. y = zeros(6, x.dtype)
  76. with pytest.raises(Exception, match='failed for 1st keyword'):
  77. self.blas_func(x, y, n=4, incx=5)
  78. def test_y_bad_size(self):
  79. x = arange(12., dtype=self.dtype)
  80. y = zeros(6, x.dtype)
  81. with pytest.raises(Exception, match='failed for 1st keyword'):
  82. self.blas_func(x, y, n=3, incy=5)
  83. try:
  84. class TestSaxpy(BaseAxpy):
  85. blas_func = fblas.saxpy
  86. dtype = float32
  87. except AttributeError:
  88. class TestSaxpy:
  89. pass
  90. class TestDaxpy(BaseAxpy):
  91. blas_func = fblas.daxpy
  92. dtype = float64
  93. try:
  94. class TestCaxpy(BaseAxpy):
  95. blas_func = fblas.caxpy
  96. dtype = complex64
  97. except AttributeError:
  98. class TestCaxpy:
  99. pass
  100. class TestZaxpy(BaseAxpy):
  101. blas_func = fblas.zaxpy
  102. dtype = complex128
  103. ##################################################
  104. # Test blas ?scal
  105. class BaseScal:
  106. ''' Mixin class for scal testing '''
  107. def test_simple(self):
  108. x = arange(3., dtype=self.dtype)
  109. real_x = x*3.
  110. x = self.blas_func(3., x)
  111. assert_array_equal(real_x, x)
  112. def test_x_stride(self):
  113. x = arange(6., dtype=self.dtype)
  114. real_x = x.copy()
  115. real_x[::2] = x[::2]*array(3., self.dtype)
  116. x = self.blas_func(3., x, n=3, incx=2)
  117. assert_array_equal(real_x, x)
  118. def test_x_bad_size(self):
  119. x = arange(12., dtype=self.dtype)
  120. with pytest.raises(Exception, match='failed for 1st keyword'):
  121. self.blas_func(2., x, n=4, incx=5)
  122. try:
  123. class TestSscal(BaseScal):
  124. blas_func = fblas.sscal
  125. dtype = float32
  126. except AttributeError:
  127. class TestSscal:
  128. pass
  129. class TestDscal(BaseScal):
  130. blas_func = fblas.dscal
  131. dtype = float64
  132. try:
  133. class TestCscal(BaseScal):
  134. blas_func = fblas.cscal
  135. dtype = complex64
  136. except AttributeError:
  137. class TestCscal:
  138. pass
  139. class TestZscal(BaseScal):
  140. blas_func = fblas.zscal
  141. dtype = complex128
  142. ##################################################
  143. # Test blas ?copy
  144. class BaseCopy:
  145. ''' Mixin class for copy testing '''
  146. def test_simple(self):
  147. x = arange(3., dtype=self.dtype)
  148. y = zeros(shape(x), x.dtype)
  149. y = self.blas_func(x, y)
  150. assert_array_equal(x, y)
  151. def test_x_stride(self):
  152. x = arange(6., dtype=self.dtype)
  153. y = zeros(3, x.dtype)
  154. y = self.blas_func(x, y, n=3, incx=2)
  155. assert_array_equal(x[::2], y)
  156. def test_y_stride(self):
  157. x = arange(3., dtype=self.dtype)
  158. y = zeros(6, x.dtype)
  159. y = self.blas_func(x, y, n=3, incy=2)
  160. assert_array_equal(x, y[::2])
  161. def test_x_and_y_stride(self):
  162. x = arange(12., dtype=self.dtype)
  163. y = zeros(6, x.dtype)
  164. y = self.blas_func(x, y, n=3, incx=4, incy=2)
  165. assert_array_equal(x[::4], y[::2])
  166. def test_x_bad_size(self):
  167. x = arange(12., dtype=self.dtype)
  168. y = zeros(6, x.dtype)
  169. with pytest.raises(Exception, match='failed for 1st keyword'):
  170. self.blas_func(x, y, n=4, incx=5)
  171. def test_y_bad_size(self):
  172. x = arange(12., dtype=self.dtype)
  173. y = zeros(6, x.dtype)
  174. with pytest.raises(Exception, match='failed for 1st keyword'):
  175. self.blas_func(x, y, n=3, incy=5)
  176. # def test_y_bad_type(self):
  177. ## Hmmm. Should this work? What should be the output.
  178. # x = arange(3.,dtype=self.dtype)
  179. # y = zeros(shape(x))
  180. # self.blas_func(x,y)
  181. # assert_array_equal(x,y)
  182. try:
  183. class TestScopy(BaseCopy):
  184. blas_func = fblas.scopy
  185. dtype = float32
  186. except AttributeError:
  187. class TestScopy:
  188. pass
  189. class TestDcopy(BaseCopy):
  190. blas_func = fblas.dcopy
  191. dtype = float64
  192. try:
  193. class TestCcopy(BaseCopy):
  194. blas_func = fblas.ccopy
  195. dtype = complex64
  196. except AttributeError:
  197. class TestCcopy:
  198. pass
  199. class TestZcopy(BaseCopy):
  200. blas_func = fblas.zcopy
  201. dtype = complex128
  202. ##################################################
  203. # Test blas ?swap
  204. class BaseSwap:
  205. ''' Mixin class for swap tests '''
  206. def test_simple(self):
  207. x = arange(3., dtype=self.dtype)
  208. y = zeros(shape(x), x.dtype)
  209. desired_x = y.copy()
  210. desired_y = x.copy()
  211. x, y = self.blas_func(x, y)
  212. assert_array_equal(desired_x, x)
  213. assert_array_equal(desired_y, y)
  214. def test_x_stride(self):
  215. x = arange(6., dtype=self.dtype)
  216. y = zeros(3, x.dtype)
  217. desired_x = y.copy()
  218. desired_y = x.copy()[::2]
  219. x, y = self.blas_func(x, y, n=3, incx=2)
  220. assert_array_equal(desired_x, x[::2])
  221. assert_array_equal(desired_y, y)
  222. def test_y_stride(self):
  223. x = arange(3., dtype=self.dtype)
  224. y = zeros(6, x.dtype)
  225. desired_x = y.copy()[::2]
  226. desired_y = x.copy()
  227. x, y = self.blas_func(x, y, n=3, incy=2)
  228. assert_array_equal(desired_x, x)
  229. assert_array_equal(desired_y, y[::2])
  230. def test_x_and_y_stride(self):
  231. x = arange(12., dtype=self.dtype)
  232. y = zeros(6, x.dtype)
  233. desired_x = y.copy()[::2]
  234. desired_y = x.copy()[::4]
  235. x, y = self.blas_func(x, y, n=3, incx=4, incy=2)
  236. assert_array_equal(desired_x, x[::4])
  237. assert_array_equal(desired_y, y[::2])
  238. def test_x_bad_size(self):
  239. x = arange(12., dtype=self.dtype)
  240. y = zeros(6, x.dtype)
  241. with pytest.raises(Exception, match='failed for 1st keyword'):
  242. self.blas_func(x, y, n=4, incx=5)
  243. def test_y_bad_size(self):
  244. x = arange(12., dtype=self.dtype)
  245. y = zeros(6, x.dtype)
  246. with pytest.raises(Exception, match='failed for 1st keyword'):
  247. self.blas_func(x, y, n=3, incy=5)
  248. try:
  249. class TestSswap(BaseSwap):
  250. blas_func = fblas.sswap
  251. dtype = float32
  252. except AttributeError:
  253. class TestSswap:
  254. pass
  255. class TestDswap(BaseSwap):
  256. blas_func = fblas.dswap
  257. dtype = float64
  258. try:
  259. class TestCswap(BaseSwap):
  260. blas_func = fblas.cswap
  261. dtype = complex64
  262. except AttributeError:
  263. class TestCswap:
  264. pass
  265. class TestZswap(BaseSwap):
  266. blas_func = fblas.zswap
  267. dtype = complex128
  268. ##################################################
  269. # Test blas ?gemv
  270. # This will be a mess to test all cases.
  271. class BaseGemv:
  272. ''' Mixin class for gemv tests '''
  273. def get_data(self, x_stride=1, y_stride=1):
  274. rng = np.random.default_rng(1234)
  275. mult = array(1, dtype=self.dtype)
  276. if self.dtype in [complex64, complex128]:
  277. mult = array(1+1j, dtype=self.dtype)
  278. alpha = array(1., dtype=self.dtype) * mult
  279. beta = array(1., dtype=self.dtype) * mult
  280. a = rng.normal(0., 1., (3, 3)).astype(self.dtype) * mult
  281. x = arange(shape(a)[0]*x_stride, dtype=self.dtype) * mult
  282. y = arange(shape(a)[1]*y_stride, dtype=self.dtype) * mult
  283. return alpha, beta, a, x, y
  284. def test_simple(self):
  285. alpha, beta, a, x, y = self.get_data()
  286. desired_y = alpha*matrixmultiply(a, x)+beta*y
  287. y = self.blas_func(alpha, a, x, beta, y)
  288. assert_array_almost_equal(desired_y, y)
  289. def test_default_beta_y(self):
  290. alpha, beta, a, x, y = self.get_data()
  291. desired_y = matrixmultiply(a, x)
  292. y = self.blas_func(1, a, x)
  293. assert_array_almost_equal(desired_y, y)
  294. def test_simple_transpose(self):
  295. alpha, beta, a, x, y = self.get_data()
  296. desired_y = alpha*matrixmultiply(transpose(a), x)+beta*y
  297. y = self.blas_func(alpha, a, x, beta, y, trans=1)
  298. assert_array_almost_equal(desired_y, y)
  299. def test_simple_transpose_conj(self):
  300. alpha, beta, a, x, y = self.get_data()
  301. desired_y = alpha*matrixmultiply(transpose(conjugate(a)), x)+beta*y
  302. y = self.blas_func(alpha, a, x, beta, y, trans=2)
  303. assert_array_almost_equal(desired_y, y)
  304. def test_x_stride(self):
  305. alpha, beta, a, x, y = self.get_data(x_stride=2)
  306. desired_y = alpha*matrixmultiply(a, x[::2])+beta*y
  307. y = self.blas_func(alpha, a, x, beta, y, incx=2)
  308. assert_array_almost_equal(desired_y, y)
  309. def test_x_stride_transpose(self):
  310. alpha, beta, a, x, y = self.get_data(x_stride=2)
  311. desired_y = alpha*matrixmultiply(transpose(a), x[::2])+beta*y
  312. y = self.blas_func(alpha, a, x, beta, y, trans=1, incx=2)
  313. assert_array_almost_equal(desired_y, y)
  314. def test_x_stride_assert(self):
  315. # What is the use of this test?
  316. alpha, beta, a, x, y = self.get_data(x_stride=2)
  317. with pytest.raises(Exception, match='failed for 3rd argument'):
  318. y = self.blas_func(1, a, x, 1, y, trans=0, incx=3)
  319. with pytest.raises(Exception, match='failed for 3rd argument'):
  320. y = self.blas_func(1, a, x, 1, y, trans=1, incx=3)
  321. def test_y_stride(self):
  322. alpha, beta, a, x, y = self.get_data(y_stride=2)
  323. desired_y = y.copy()
  324. desired_y[::2] = alpha*matrixmultiply(a, x)+beta*y[::2]
  325. y = self.blas_func(alpha, a, x, beta, y, incy=2)
  326. assert_array_almost_equal(desired_y, y)
  327. def test_y_stride_transpose(self):
  328. alpha, beta, a, x, y = self.get_data(y_stride=2)
  329. desired_y = y.copy()
  330. desired_y[::2] = alpha*matrixmultiply(transpose(a), x)+beta*y[::2]
  331. y = self.blas_func(alpha, a, x, beta, y, trans=1, incy=2)
  332. assert_array_almost_equal(desired_y, y)
  333. def test_y_stride_assert(self):
  334. # What is the use of this test?
  335. alpha, beta, a, x, y = self.get_data(y_stride=2)
  336. with pytest.raises(Exception, match='failed for 2nd keyword'):
  337. y = self.blas_func(1, a, x, 1, y, trans=0, incy=3)
  338. with pytest.raises(Exception, match='failed for 2nd keyword'):
  339. y = self.blas_func(1, a, x, 1, y, trans=1, incy=3)
  340. try:
  341. class TestSgemv(BaseGemv):
  342. blas_func = fblas.sgemv
  343. dtype = float32
  344. @pytest.mark.skipif(sys.platform != 'darwin', reason="MacOS specific test")
  345. def test_sgemv_on_osx(self):
  346. def aligned_array(shape, align, dtype, order='C'):
  347. # Make array shape `shape` with aligned at `align` bytes
  348. d = dtype()
  349. # Make array of correct size with `align` extra bytes
  350. N = np.prod(shape)
  351. tmp = np.zeros(N * d.nbytes + align, dtype=np.uint8)
  352. address = tmp.__array_interface__["data"][0]
  353. # Find offset into array giving desired alignment
  354. for offset in range(align):
  355. if (address + offset) % align == 0:
  356. break
  357. tmp = tmp[offset:offset+N*d.nbytes].view(dtype=dtype)
  358. return tmp.reshape(shape, order=order)
  359. def as_aligned(arr, align, dtype, order='C'):
  360. # Copy `arr` into an aligned array with same shape
  361. aligned = aligned_array(arr.shape, align, dtype, order)
  362. aligned[:] = arr[:]
  363. return aligned
  364. def assert_dot_close(A, X, desired):
  365. assert_allclose(self.blas_func(1.0, A, X), desired,
  366. rtol=1e-5, atol=1e-7)
  367. testdata = product((15, 32), (10000,), (200, 89), ('C', 'F'))
  368. rng = np.random.default_rng(1234)
  369. for align, m, n, a_order in testdata:
  370. A_d = rng.random((m, n))
  371. X_d = rng.random(n)
  372. desired = np.dot(A_d, X_d)
  373. # Calculation with aligned single precision
  374. A_f = as_aligned(A_d, align, np.float32, order=a_order)
  375. X_f = as_aligned(X_d, align, np.float32, order=a_order)
  376. assert_dot_close(A_f, X_f, desired)
  377. except AttributeError:
  378. class TestSgemv:
  379. pass
  380. class TestDgemv(BaseGemv):
  381. blas_func = fblas.dgemv
  382. dtype = float64
  383. try:
  384. class TestCgemv(BaseGemv):
  385. blas_func = fblas.cgemv
  386. dtype = complex64
  387. except AttributeError:
  388. class TestCgemv:
  389. pass
  390. class TestZgemv(BaseGemv):
  391. blas_func = fblas.zgemv
  392. dtype = complex128
# Disabled ?ger test suite.  Everything between the triple quotes is a bare
# module-level string expression, so none of these classes are defined and
# none of the tests run.
# NOTE(review): the *_assert tests use match='foo', which looks like a
# placeholder, and test_y_stride_assert passes (a, x, y) where the other
# calls pass (x, y, a) -- confirm both before re-enabling.
"""
##################################################
### Test blas ?ger
### This will be a mess to test all cases.
class BaseGer:
    def get_data(self,x_stride=1,y_stride=1):
        rng = np.random.default_rng(1234)
        alpha = array(1., dtype = self.dtype)
        a = rng.normal(0.,1.,(3,3)).astype(self.dtype)
        x = arange(shape(a)[0]*x_stride,dtype=self.dtype)
        y = arange(shape(a)[1]*y_stride,dtype=self.dtype)
        return alpha,a,x,y
    def test_simple(self):
        alpha,a,x,y = self.get_data()
        # transpose takes care of Fortran vs. C(and Python) memory layout
        desired_a = alpha*transpose(x[:,newaxis]*y) + a
        self.blas_func(x,y,a)
        assert_array_almost_equal(desired_a,a)
    def test_x_stride(self):
        alpha,a,x,y = self.get_data(x_stride=2)
        desired_a = alpha*transpose(x[::2,newaxis]*y) + a
        self.blas_func(x,y,a,incx=2)
        assert_array_almost_equal(desired_a,a)
    def test_x_stride_assert(self):
        alpha,a,x,y = self.get_data(x_stride=2)
        with pytest.raises(ValueError, match='foo'):
            self.blas_func(x,y,a,incx=3)
    def test_y_stride(self):
        alpha,a,x,y = self.get_data(y_stride=2)
        desired_a = alpha*transpose(x[:,newaxis]*y[::2]) + a
        self.blas_func(x,y,a,incy=2)
        assert_array_almost_equal(desired_a,a)
    def test_y_stride_assert(self):
        alpha,a,x,y = self.get_data(y_stride=2)
        with pytest.raises(ValueError, match='foo'):
            self.blas_func(a,x,y,incy=3)
class TestSger(BaseGer):
    blas_func = fblas.sger
    dtype = float32
class TestDger(BaseGer):
    blas_func = fblas.dger
    dtype = float64
"""
  436. ##################################################
  437. # Test blas ?gerc
  438. # This will be a mess to test all cases.
# Disabled ?geru/?gerc test suite.  Everything between the triple quotes is a
# bare module-level string expression, so none of these classes are defined
# and none of the tests run.
# NOTE(review): BaseGerComplex derives from BaseGer, which in the visible
# source only exists inside another disabled string block, and test_simple
# calls fblas.cgeru directly instead of self.blas_func -- confirm both before
# re-enabling.
"""
class BaseGerComplex(BaseGer):
    def get_data(self,x_stride=1,y_stride=1):
        rng = np.random.default_rng(1234)
        alpha = array(1+1j, dtype = self.dtype)
        a = rng.normal(0.,1.,(3,3)).astype(self.dtype)
        a = a + rng.normal(0.,1.,(3,3)) * array(1j, dtype = self.dtype)
        x = rng.normal(0.,1.,shape(a)[0]*x_stride).astype(self.dtype)
        x = x + x * array(1j, dtype = self.dtype)
        y = rng.normal(0.,1.,shape(a)[1]*y_stride).astype(self.dtype)
        y = y + y * array(1j, dtype = self.dtype)
        return alpha,a,x,y
    def test_simple(self):
        alpha,a,x,y = self.get_data()
        # transpose takes care of Fortran vs. C(and Python) memory layout
        a = a * array(0.,dtype = self.dtype)
        #desired_a = alpha*transpose(x[:,newaxis]*self.transform(y)) + a
        desired_a = alpha*transpose(x[:,newaxis]*y) + a
        #self.blas_func(x,y,a,alpha = alpha)
        fblas.cgeru(x,y,a,alpha = alpha)
        assert_array_almost_equal(desired_a,a)
    #def test_x_stride(self):
    #    alpha,a,x,y = self.get_data(x_stride=2)
    #    desired_a = alpha*transpose(x[::2,newaxis]*self.transform(y)) + a
    #    self.blas_func(x,y,a,incx=2)
    #    assert_array_almost_equal(desired_a,a)
    #def test_y_stride(self):
    #    alpha,a,x,y = self.get_data(y_stride=2)
    #    desired_a = alpha*transpose(x[:,newaxis]*self.transform(y[::2])) + a
    #    self.blas_func(x,y,a,incy=2)
    #    assert_array_almost_equal(desired_a,a)
class TestCgeru(BaseGerComplex):
    blas_func = fblas.cgeru
    dtype = complex64
    def transform(self,x):
        return x
class TestZgeru(BaseGerComplex):
    blas_func = fblas.zgeru
    dtype = complex128
    def transform(self,x):
        return x
class TestCgerc(BaseGerComplex):
    blas_func = fblas.cgerc
    dtype = complex64
    def transform(self,x):
        return conjugate(x)
class TestZgerc(BaseGerComplex):
    blas_func = fblas.zgerc
    dtype = complex128
    def transform(self,x):
        return conjugate(x)
"""