blas.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. """
  2. Low-level BLAS functions (:mod:`scipy.linalg.blas`)
  3. ===================================================
  4. This module contains low-level functions from the BLAS library.
  5. .. versionadded:: 0.12.0
  6. .. note::
  7. The common ``overwrite_<>`` option in many routines allows the
  8. input arrays to be overwritten to avoid extra memory allocation.
  9. However, this requires the array to satisfy two conditions:
  10. its memory order and its data type must match exactly the
  11. order and the type expected by the routine.
  12. As an example, if you pass a double precision float array to any
  13. ``S....`` routine which expects single precision arguments, f2py
  14. will create an intermediate array to match the argument types and
  15. overwriting will be performed on that intermediate array.
  16. Similarly, if a C-contiguous array is passed, f2py will pass a
  17. FORTRAN-contiguous copy internally, so the original array is not
  18. overwritten. Please make sure that these details are satisfied.
  19. More information can be found in the f2py documentation.
  20. .. warning::
  21. These functions do little to no error checking.
  22. It is possible to cause crashes by mis-using them,
  23. so prefer using the higher-level routines in `scipy.linalg`.
  24. Finding functions
  25. -----------------
  26. .. autosummary::
  27. :toctree: generated/
  28. get_blas_funcs
  29. find_best_blas_type
  30. BLAS Level 1 functions
  31. ----------------------
  32. .. autosummary::
  33. :toctree: generated/
  34. sasum
  35. saxpy
  36. scasum
  37. scnrm2
  38. scopy
  39. sdot
  40. snrm2
  41. srot
  42. srotg
  43. srotm
  44. srotmg
  45. sscal
  46. sswap
  47. dasum
  48. daxpy
  49. dcopy
  50. ddot
  51. dnrm2
  52. drot
  53. drotg
  54. drotm
  55. drotmg
  56. dscal
  57. dswap
  58. dzasum
  59. dznrm2
  60. icamax
  61. idamax
  62. isamax
  63. izamax
  64. caxpy
  65. ccopy
  66. cdotc
  67. cdotu
  68. crotg
  69. cscal
  70. csrot
  71. csscal
  72. cswap
  73. zaxpy
  74. zcopy
  75. zdotc
  76. zdotu
  77. zdrot
  78. zdscal
  79. zrotg
  80. zscal
  81. zswap
  82. BLAS Level 2 functions
  83. ----------------------
  84. .. autosummary::
  85. :toctree: generated/
  86. sgbmv
  87. sgemv
  88. sger
  89. ssbmv
  90. sspmv
  91. sspr
  92. sspr2
  93. ssymv
  94. ssyr
  95. ssyr2
  96. stbmv
  97. stbsv
  98. stpmv
  99. stpsv
  100. strmv
  101. strsv
  102. dgbmv
  103. dgemv
  104. dger
  105. dsbmv
  106. dspmv
  107. dspr
  108. dspr2
  109. dsymv
  110. dsyr
  111. dsyr2
  112. dtbmv
  113. dtbsv
  114. dtpmv
  115. dtpsv
  116. dtrmv
  117. dtrsv
  118. cgbmv
  119. cgemv
  120. cgerc
  121. cgeru
  122. chbmv
  123. chemv
  124. cher
  125. cher2
  126. chpmv
  127. chpr
  128. chpr2
  129. cspmv
  130. cspr
  131. csyr
  132. ctbmv
  133. ctbsv
  134. ctpmv
  135. ctpsv
  136. ctrmv
  137. ctrsv
  138. zgbmv
  139. zgemv
  140. zgerc
  141. zgeru
  142. zhbmv
  143. zhemv
  144. zher
  145. zher2
  146. zhpmv
  147. zhpr
  148. zhpr2
  149. zspmv
  150. zspr
  151. zsyr
  152. ztbmv
  153. ztbsv
  154. ztpmv
  155. ztpsv
  156. ztrmv
  157. ztrsv
  158. BLAS Level 3 functions
  159. ----------------------
  160. .. autosummary::
  161. :toctree: generated/
  162. sgemm
  163. ssymm
  164. ssyr2k
  165. ssyrk
  166. strmm
  167. strsm
  168. dgemm
  169. dsymm
  170. dsyr2k
  171. dsyrk
  172. dtrmm
  173. dtrsm
  174. cgemm
  175. chemm
  176. cher2k
  177. cherk
  178. csymm
  179. csyr2k
  180. csyrk
  181. ctrmm
  182. ctrsm
  183. zgemm
  184. zhemm
  185. zher2k
  186. zherk
  187. zsymm
  188. zsyr2k
  189. zsyrk
  190. ztrmm
  191. ztrsm
  192. """
  193. #
  194. # Author: Pearu Peterson, March 2002
  195. # refactoring by Fabian Pedregosa, March 2010
  196. #
  197. __all__ = ['get_blas_funcs', 'find_best_blas_type']
  198. import numpy as np
  199. import functools
  200. from scipy.linalg import _fblas
  201. try:
  202. from scipy.linalg import _cblas
  203. except ImportError:
  204. _cblas = None
  205. from scipy.__config__ import CONFIG
  206. HAS_ILP64 = CONFIG['Build Dependencies']['blas']['has ilp64']
  207. del CONFIG
  208. _fblas_64 = None
  209. if HAS_ILP64:
  210. from scipy.linalg import _fblas_64
  211. # Expose all functions (only fblas --- cblas is an implementation detail)
  212. empty_module = None
  213. from scipy.linalg._fblas import * # noqa: E402, F403
  214. del empty_module
  215. # all numeric dtypes '?bBhHiIlLqQefdgFDGO' that are safe to be converted to
  216. # single precision float : '?bBhH!!!!!!ef!!!!!!'
  217. # double precision float : '?bBhHiIlLqQefdg!!!!'
  218. # single precision complex : '?bBhH!!!!!!ef!!F!!!'
  219. # double precision complex : '?bBhHiIlLqQefdgFDG!'
  220. _type_score = {x: 1 for x in '?bBhHef'}
  221. _type_score.update({x: 2 for x in 'iIlLqQd'})
  222. # Handle float128(g) and complex256(G) separately in case non-Windows systems.
  223. # On Windows, the values will be rewritten to the same key with the same value.
  224. _type_score.update({'F': 3, 'D': 4, 'g': 2, 'G': 4})
  225. # Final mapping to the actual prefixes and dtypes
  226. _type_conv = {1: ('s', np.dtype('float32')),
  227. 2: ('d', np.dtype('float64')),
  228. 3: ('c', np.dtype('complex64')),
  229. 4: ('z', np.dtype('complex128'))}
  230. # some convenience alias for complex functions
  231. _blas_alias = {'cnrm2': 'scnrm2', 'znrm2': 'dznrm2',
  232. 'cdot': 'cdotc', 'zdot': 'zdotc',
  233. 'cger': 'cgerc', 'zger': 'zgerc',
  234. 'sdotc': 'sdot', 'sdotu': 'sdot',
  235. 'ddotc': 'ddot', 'ddotu': 'ddot'}
  236. def find_best_blas_type(arrays=(), dtype=None):
  237. """Find best-matching BLAS/LAPACK type.
  238. Arrays are used to determine the optimal prefix of BLAS routines.
  239. Parameters
  240. ----------
  241. arrays : sequence of ndarrays, optional
  242. Arrays can be given to determine optimal prefix of BLAS
  243. routines. If not given, double-precision routines will be
  244. used, otherwise the most generic type in arrays will be used.
  245. dtype : str or dtype, optional
  246. Data-type specifier. Not used if `arrays` is non-empty.
  247. Returns
  248. -------
  249. prefix : str
  250. BLAS/LAPACK prefix character.
  251. dtype : dtype
  252. Inferred Numpy data type.
  253. prefer_fortran : bool
  254. Whether to prefer Fortran order routines over C order.
  255. Examples
  256. --------
  257. >>> import numpy as np
  258. >>> import scipy.linalg.blas as bla
  259. >>> rng = np.random.default_rng()
  260. >>> a = rng.random((10,15))
  261. >>> b = np.asfortranarray(a) # Change the memory layout order
  262. >>> bla.find_best_blas_type((a,))
  263. ('d', dtype('float64'), False)
  264. >>> bla.find_best_blas_type((a*1j,))
  265. ('z', dtype('complex128'), False)
  266. >>> bla.find_best_blas_type((b,))
  267. ('d', dtype('float64'), True)
  268. """
  269. dtype = np.dtype(dtype)
  270. max_score = _type_score.get(dtype.char, 5)
  271. prefer_fortran = False
  272. if arrays:
  273. # In most cases, single element is passed through, quicker route
  274. if len(arrays) == 1:
  275. max_score = _type_score.get(arrays[0].dtype.char, 5)
  276. prefer_fortran = arrays[0].flags['FORTRAN']
  277. else:
  278. # use the most generic type in arrays
  279. scores = [_type_score.get(x.dtype.char, 5) for x in arrays]
  280. max_score = max(scores)
  281. ind_max_score = scores.index(max_score)
  282. # safe upcasting for mix of float64 and complex64 --> prefix 'z'
  283. if max_score == 3 and (2 in scores):
  284. max_score = 4
  285. if arrays[ind_max_score].flags['FORTRAN']:
  286. # prefer Fortran for leading array with column major order
  287. prefer_fortran = True
  288. # Get the LAPACK prefix and the corresponding dtype if not fall back
  289. # to 'd' and double precision float.
  290. prefix, dtype = _type_conv.get(max_score, ('d', np.dtype('float64')))
  291. return prefix, dtype, prefer_fortran
  292. def _get_funcs(names, arrays, dtype,
  293. lib_name, fmodule, cmodule,
  294. fmodule_name, cmodule_name, alias,
  295. ilp64=False):
  296. """
  297. Return available BLAS/LAPACK functions.
  298. Used also in lapack.py. See get_blas_funcs for docstring.
  299. """
  300. funcs = []
  301. unpack = False
  302. dtype = np.dtype(dtype)
  303. module1 = (cmodule, cmodule_name)
  304. module2 = (fmodule, fmodule_name)
  305. if isinstance(names, str):
  306. names = (names,)
  307. unpack = True
  308. prefix, dtype, prefer_fortran = find_best_blas_type(arrays, dtype)
  309. if prefer_fortran:
  310. module1, module2 = module2, module1
  311. for name in names:
  312. func_name = prefix + name
  313. func_name = alias.get(func_name, func_name)
  314. func = getattr(module1[0], func_name, None)
  315. module_name = module1[1]
  316. if func is None:
  317. func = getattr(module2[0], func_name, None)
  318. module_name = module2[1]
  319. if func is None:
  320. raise ValueError(
  321. f'{lib_name} function {func_name} could not be found')
  322. func.module_name, func.typecode = module_name, prefix
  323. func.dtype = dtype
  324. if not ilp64:
  325. func.int_dtype = np.dtype(np.intc)
  326. else:
  327. func.int_dtype = np.dtype(np.int64)
  328. func.prefix = prefix # Backward compatibility
  329. funcs.append(func)
  330. if unpack:
  331. return funcs[0]
  332. else:
  333. return funcs
  334. def _memoize_get_funcs(func):
  335. """
  336. Memoized fast path for _get_funcs instances
  337. """
  338. memo = {}
  339. func.memo = memo
  340. @functools.wraps(func)
  341. def getter(names, arrays=(), dtype=None, ilp64=False):
  342. key = (names, dtype, ilp64)
  343. for array in arrays:
  344. # cf. find_blas_funcs
  345. key += (array.dtype.char, array.flags.fortran)
  346. try:
  347. value = memo.get(key)
  348. except TypeError:
  349. # unhashable key etc.
  350. key = None
  351. value = None
  352. if value is not None:
  353. return value
  354. value = func(names, arrays, dtype, ilp64)
  355. if key is not None:
  356. memo[key] = value
  357. return value
  358. return getter
  359. @_memoize_get_funcs
  360. def get_blas_funcs(names, arrays=(), dtype=None, ilp64=False):
  361. """Return available BLAS function objects from names.
  362. Arrays are used to determine the optimal prefix of BLAS routines.
  363. Parameters
  364. ----------
  365. names : str or sequence of str
  366. Name(s) of BLAS functions without type prefix.
  367. arrays : sequence of ndarrays, optional
  368. Arrays can be given to determine optimal prefix of BLAS
  369. routines. If not given, double-precision routines will be
  370. used, otherwise the most generic type in arrays will be used.
  371. dtype : str or dtype, optional
  372. Data-type specifier. Not used if `arrays` is non-empty.
  373. ilp64 : {True, False, 'preferred'}, optional
  374. Whether to return ILP64 routine variant.
  375. Choosing 'preferred' returns ILP64 routine if available,
  376. and otherwise the 32-bit routine. Default: False
  377. Returns
  378. -------
  379. funcs : list
  380. List containing the found function(s).
  381. Notes
  382. -----
  383. This routine automatically chooses between Fortran/C
  384. interfaces. Fortran code is used whenever possible for arrays with
  385. column major order. In all other cases, C code is preferred.
  386. In BLAS, the naming convention is that all functions start with a
  387. type prefix, which depends on the type of the principal
  388. matrix. These can be one of {'s', 'd', 'c', 'z'} for the NumPy
  389. types {float32, float64, complex64, complex128} respectively.
  390. The code and the dtype are stored in attributes `typecode` and `dtype`
  391. of the returned functions.
  392. Examples
  393. --------
  394. >>> import numpy as np
  395. >>> import scipy.linalg as LA
  396. >>> rng = np.random.default_rng()
  397. >>> a = rng.random((3,2))
  398. >>> x_gemv = LA.get_blas_funcs('gemv', (a,))
  399. >>> x_gemv.typecode
  400. 'd'
  401. >>> x_gemv = LA.get_blas_funcs('gemv',(a*1j,))
  402. >>> x_gemv.typecode
  403. 'z'
  404. """
  405. if isinstance(ilp64, str):
  406. if ilp64 == 'preferred':
  407. ilp64 = HAS_ILP64
  408. else:
  409. raise ValueError("Invalid value for 'ilp64'")
  410. if not ilp64:
  411. return _get_funcs(names, arrays, dtype,
  412. "BLAS", _fblas, _cblas, "fblas", "cblas",
  413. _blas_alias, ilp64=False)
  414. else:
  415. if not HAS_ILP64:
  416. raise RuntimeError("BLAS ILP64 routine requested, but Scipy "
  417. "compiled only with 32-bit BLAS")
  418. return _get_funcs(names, arrays, dtype,
  419. "BLAS", _fblas_64, None, "fblas_64", None,
  420. _blas_alias, ilp64=True)