// CUDASparseBlas.h
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
/*
  Provides a subset of cuSPARSE functions as templates, for example:

    csrgeam2<scalar_t>(...)

  where scalar_t is double, float, c10::complex<double> or c10::complex<float>.
  The functions are available in the at::cuda::sparse namespace.
*/
  9. #include <ATen/cuda/CUDAContext.h>
  10. #include <ATen/cuda/CUDASparse.h>
  11. // NOLINTBEGIN(misc-misplaced-const)
  12. namespace at::cuda::sparse {
  13. #define CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t) \
  14. cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \
  15. const cusparseMatDescr_t descrA, int nnzA, \
  16. const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \
  17. const int *csrSortedColIndA, const scalar_t *beta, \
  18. const cusparseMatDescr_t descrB, int nnzB, \
  19. const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \
  20. const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
  21. const scalar_t *csrSortedValC, const int *csrSortedRowPtrC, \
  22. const int *csrSortedColIndC, size_t *pBufferSizeInBytes
  23. template <typename scalar_t>
  24. inline void csrgeam2_bufferSizeExt(
  25. CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t)) {
  26. TORCH_INTERNAL_ASSERT(
  27. false,
  28. "at::cuda::sparse::csrgeam2_bufferSizeExt: not implemented for ",
  29. typeid(scalar_t).name());
  30. }
  31. template <>
  32. void csrgeam2_bufferSizeExt<float>(
  33. CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(float));
  34. template <>
  35. void csrgeam2_bufferSizeExt<double>(
  36. CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(double));
  37. template <>
  38. void csrgeam2_bufferSizeExt<c10::complex<float>>(
  39. CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<float>));
  40. template <>
  41. void csrgeam2_bufferSizeExt<c10::complex<double>>(
  42. CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<double>));
  43. #define CUSPARSE_CSRGEAM2_NNZ_ARGTYPES() \
  44. cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, \
  45. int nnzA, const int *csrSortedRowPtrA, const int *csrSortedColIndA, \
  46. const cusparseMatDescr_t descrB, int nnzB, const int *csrSortedRowPtrB, \
  47. const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
  48. int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *workspace
  49. template <typename scalar_t>
  50. inline void csrgeam2Nnz(CUSPARSE_CSRGEAM2_NNZ_ARGTYPES()) {
  51. TORCH_CUDASPARSE_CHECK(cusparseXcsrgeam2Nnz(
  52. handle,
  53. m,
  54. n,
  55. descrA,
  56. nnzA,
  57. csrSortedRowPtrA,
  58. csrSortedColIndA,
  59. descrB,
  60. nnzB,
  61. csrSortedRowPtrB,
  62. csrSortedColIndB,
  63. descrC,
  64. csrSortedRowPtrC,
  65. nnzTotalDevHostPtr,
  66. workspace));
  67. }
  68. #define CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t) \
  69. cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \
  70. const cusparseMatDescr_t descrA, int nnzA, \
  71. const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \
  72. const int *csrSortedColIndA, const scalar_t *beta, \
  73. const cusparseMatDescr_t descrB, int nnzB, \
  74. const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \
  75. const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
  76. scalar_t *csrSortedValC, int *csrSortedRowPtrC, int *csrSortedColIndC, \
  77. void *pBuffer
  78. template <typename scalar_t>
  79. inline void csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t)) {
  80. TORCH_INTERNAL_ASSERT(
  81. false,
  82. "at::cuda::sparse::csrgeam2: not implemented for ",
  83. typeid(scalar_t).name());
  84. }
  85. template <>
  86. void csrgeam2<float>(CUSPARSE_CSRGEAM2_ARGTYPES(float));
  87. template <>
  88. void csrgeam2<double>(CUSPARSE_CSRGEAM2_ARGTYPES(double));
  89. template <>
  90. void csrgeam2<c10::complex<float>>(
  91. CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<float>));
  92. template <>
  93. void csrgeam2<c10::complex<double>>(
  94. CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<double>));
  95. #define CUSPARSE_BSRMM_ARGTYPES(scalar_t) \
  96. cusparseHandle_t handle, cusparseDirection_t dirA, \
  97. cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, \
  98. int kb, int nnzb, const scalar_t *alpha, \
  99. const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
  100. const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
  101. const scalar_t *B, int ldb, const scalar_t *beta, scalar_t *C, int ldc
  102. template <typename scalar_t>
  103. inline void bsrmm(CUSPARSE_BSRMM_ARGTYPES(scalar_t)) {
  104. TORCH_INTERNAL_ASSERT(
  105. false,
  106. "at::cuda::sparse::bsrmm: not implemented for ",
  107. typeid(scalar_t).name());
  108. }
  109. template <>
  110. void bsrmm<float>(CUSPARSE_BSRMM_ARGTYPES(float));
  111. template <>
  112. void bsrmm<double>(CUSPARSE_BSRMM_ARGTYPES(double));
  113. template <>
  114. void bsrmm<c10::complex<float>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<float>));
  115. template <>
  116. void bsrmm<c10::complex<double>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<double>));
  117. #define CUSPARSE_BSRMV_ARGTYPES(scalar_t) \
  118. cusparseHandle_t handle, cusparseDirection_t dirA, \
  119. cusparseOperation_t transA, int mb, int nb, int nnzb, \
  120. const scalar_t *alpha, const cusparseMatDescr_t descrA, \
  121. const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
  122. int blockDim, const scalar_t *x, const scalar_t *beta, scalar_t *y
  123. template <typename scalar_t>
  124. inline void bsrmv(CUSPARSE_BSRMV_ARGTYPES(scalar_t)) {
  125. TORCH_INTERNAL_ASSERT(
  126. false,
  127. "at::cuda::sparse::bsrmv: not implemented for ",
  128. typeid(scalar_t).name());
  129. }
  130. template <>
  131. void bsrmv<float>(CUSPARSE_BSRMV_ARGTYPES(float));
  132. template <>
  133. void bsrmv<double>(CUSPARSE_BSRMV_ARGTYPES(double));
  134. template <>
  135. void bsrmv<c10::complex<float>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<float>));
  136. template <>
  137. void bsrmv<c10::complex<double>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<double>));
  138. #define CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t) \
  139. cusparseHandle_t handle, cusparseDirection_t dirA, \
  140. cusparseOperation_t transA, int mb, int nnzb, \
  141. const cusparseMatDescr_t descrA, scalar_t *bsrValA, \
  142. const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
  143. bsrsv2Info_t info, int *pBufferSizeInBytes
  144. template <typename scalar_t>
  145. inline void bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t)) {
  146. TORCH_INTERNAL_ASSERT(
  147. false,
  148. "at::cuda::sparse::bsrsv2_bufferSize: not implemented for ",
  149. typeid(scalar_t).name());
  150. }
  151. template <>
  152. void bsrsv2_bufferSize<float>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(float));
  153. template <>
  154. void bsrsv2_bufferSize<double>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(double));
  155. template <>
  156. void bsrsv2_bufferSize<c10::complex<float>>(
  157. CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<float>));
  158. template <>
  159. void bsrsv2_bufferSize<c10::complex<double>>(
  160. CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<double>));
  161. #define CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t) \
  162. cusparseHandle_t handle, cusparseDirection_t dirA, \
  163. cusparseOperation_t transA, int mb, int nnzb, \
  164. const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
  165. const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
  166. bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
  167. template <typename scalar_t>
  168. inline void bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t)) {
  169. TORCH_INTERNAL_ASSERT(
  170. false,
  171. "at::cuda::sparse::bsrsv2_analysis: not implemented for ",
  172. typeid(scalar_t).name());
  173. }
  174. template <>
  175. void bsrsv2_analysis<float>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(float));
  176. template <>
  177. void bsrsv2_analysis<double>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(double));
  178. template <>
  179. void bsrsv2_analysis<c10::complex<float>>(
  180. CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<float>));
  181. template <>
  182. void bsrsv2_analysis<c10::complex<double>>(
  183. CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<double>));
  184. #define CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t) \
  185. cusparseHandle_t handle, cusparseDirection_t dirA, \
  186. cusparseOperation_t transA, int mb, int nnzb, const scalar_t *alpha, \
  187. const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
  188. const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
  189. bsrsv2Info_t info, const scalar_t *x, scalar_t *y, \
  190. cusparseSolvePolicy_t policy, void *pBuffer
  191. template <typename scalar_t>
  192. inline void bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t)) {
  193. TORCH_INTERNAL_ASSERT(
  194. false,
  195. "at::cuda::sparse::bsrsv2_solve: not implemented for ",
  196. typeid(scalar_t).name());
  197. }
  198. template <>
  199. void bsrsv2_solve<float>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(float));
  200. template <>
  201. void bsrsv2_solve<double>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(double));
  202. template <>
  203. void bsrsv2_solve<c10::complex<float>>(
  204. CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<float>));
  205. template <>
  206. void bsrsv2_solve<c10::complex<double>>(
  207. CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<double>));
  208. #define CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t) \
  209. cusparseHandle_t handle, cusparseDirection_t dirA, \
  210. cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
  211. int nnzb, const cusparseMatDescr_t descrA, scalar_t *bsrValA, \
  212. const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
  213. bsrsm2Info_t info, int *pBufferSizeInBytes
  214. template <typename scalar_t>
  215. inline void bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t)) {
  216. TORCH_INTERNAL_ASSERT(
  217. false,
  218. "at::cuda::sparse::bsrsm2_bufferSize: not implemented for ",
  219. typeid(scalar_t).name());
  220. }
  221. template <>
  222. void bsrsm2_bufferSize<float>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(float));
  223. template <>
  224. void bsrsm2_bufferSize<double>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(double));
  225. template <>
  226. void bsrsm2_bufferSize<c10::complex<float>>(
  227. CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<float>));
  228. template <>
  229. void bsrsm2_bufferSize<c10::complex<double>>(
  230. CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<double>));
  231. #define CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t) \
  232. cusparseHandle_t handle, cusparseDirection_t dirA, \
  233. cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
  234. int nnzb, const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
  235. const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
  236. bsrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
  237. template <typename scalar_t>
  238. inline void bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t)) {
  239. TORCH_INTERNAL_ASSERT(
  240. false,
  241. "at::cuda::sparse::bsrsm2_analysis: not implemented for ",
  242. typeid(scalar_t).name());
  243. }
  244. template <>
  245. void bsrsm2_analysis<float>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(float));
  246. template <>
  247. void bsrsm2_analysis<double>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(double));
  248. template <>
  249. void bsrsm2_analysis<c10::complex<float>>(
  250. CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<float>));
  251. template <>
  252. void bsrsm2_analysis<c10::complex<double>>(
  253. CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<double>));
  254. #define CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t) \
  255. cusparseHandle_t handle, cusparseDirection_t dirA, \
  256. cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
  257. int nnzb, const scalar_t *alpha, const cusparseMatDescr_t descrA, \
  258. const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
  259. int blockDim, bsrsm2Info_t info, const scalar_t *B, int ldb, \
  260. scalar_t *X, int ldx, cusparseSolvePolicy_t policy, void *pBuffer
  261. template <typename scalar_t>
  262. inline void bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t)) {
  263. TORCH_INTERNAL_ASSERT(
  264. false,
  265. "at::cuda::sparse::bsrsm2_solve: not implemented for ",
  266. typeid(scalar_t).name());
  267. }
  268. template <>
  269. void bsrsm2_solve<float>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(float));
  270. template <>
  271. void bsrsm2_solve<double>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(double));
  272. template <>
  273. void bsrsm2_solve<c10::complex<float>>(
  274. CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<float>));
  275. template <>
  276. void bsrsm2_solve<c10::complex<double>>(
  277. CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<double>));
  278. } // namespace at::cuda::sparse
  279. // NOLINTEND(misc-misplaced-const)
  280. #else
  281. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  282. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)