CUDABlas.h 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
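  /*
  Provides a subset of CUDA BLAS functions as templates, for example:

    gemm<Dtype>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
    gemv<Dtype>(trans, m, n, alpha, a, lda, x, incx, beta, y, incy)
    dot<Dtype>(handle, n, x, incx, y, incy, result)

  gemm, gemv and dot are declared for Dtype = double, float,
  c10::complex<double>, c10::complex<float>, at::Half and at::BFloat16.
  gemm additionally accepts at::Half / at::BFloat16 inputs with a float
  output matrix (gemm<at::Half, float>, gemm<at::BFloat16, float>).
  The explicit specializations declared below are the authoritative list of
  supported types; for these functions, any other Dtype triggers the
  static_assert in the primary template.

  The header also declares gemm_internal, gemm_and_bias, int8_gemm,
  scaled_gemm, bgemm / bgemm_internal, trsm / trsmBatched, vdot, and the
  batched getrs / geqrf / getrf / gels routines.

  All functions are available in the at::cuda::blas namespace.
  */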
  12. #include <ATen/cuda/CUDAContext.h>
  13. #include <ATen/BlasBackend.h>
  14. #include <ATen/OpMathType.h>
  15. namespace at::cuda::blas {
  16. // RAII guard that sets the CuBLAS pointer mode and restores it to
  17. // its previous value when the guard is destroyed
  18. class PointerModeGuard {
  19. public:
  20. PointerModeGuard(cublasHandle_t handle, cublasPointerMode_t mode) :
  21. handle(handle) {
  22. TORCH_CUDABLAS_CHECK(cublasGetPointerMode(handle, &previous_mode));
  23. TORCH_CUDABLAS_CHECK(cublasSetPointerMode(handle, mode));
  24. }
  25. ~PointerModeGuard() {
  26. cublasSetPointerMode(handle, previous_mode);
  27. }
  28. private:
  29. cublasHandle_t handle;
  30. cublasPointerMode_t previous_mode{};
  31. };
  32. /* LEVEL 3 BLAS FUNCTIONS */
  33. #define CUDABLAS_GEMM_ARGTYPES(Dtype) CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype)
  34. #define CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype) \
  35. char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \
  36. const Dtype *a, int64_t lda, const Dtype *b, int64_t ldb, at::opmath_type<Dtype> beta,\
  37. C_Dtype *c, int64_t ldc
  38. #define CUDABLAS_GEMM_ARGS(Dtype) transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc
  39. #define CUDABLAS_GEMM_DTYPE_IS_FLOAT_TYPE_AND_C_DTYPE_IS_FLOAT \
  40. ((std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value) && std::is_same<C_Dtype, float>::value)
  41. template <typename Dtype, typename C_Dtype = Dtype, typename std::enable_if<!CUDABLAS_GEMM_DTYPE_IS_FLOAT_TYPE_AND_C_DTYPE_IS_FLOAT, Dtype>::type* = nullptr>
  42. inline void gemm(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) {
  43. static_assert(false&&sizeof(Dtype),"at::cuda::blas::gemm: not implemented");
  44. }
  45. template <typename Dtype, typename C_Dtype, typename std::enable_if<CUDABLAS_GEMM_DTYPE_IS_FLOAT_TYPE_AND_C_DTYPE_IS_FLOAT, Dtype>::type* = nullptr>
  46. void gemm(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype));
  47. template <>
  48. void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double));
  49. template <>
  50. void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float));
  51. template <>
  52. void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
  53. template <>
  54. void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
  55. template <>
  56. void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
  57. template <>
  58. void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
  59. template<>
  60. void gemm<at::Half, float>(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::Half, float));
  61. template<>
  62. void gemm<at::BFloat16, float>(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float));
  63. template <typename Dtype, typename C_Dtype = Dtype>
  64. inline void gemm_internal(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) {
  65. static_assert(false&&sizeof(Dtype),"at::cuda::blas::gemm_internal: not implemented");
  66. }
  67. template <>
  68. void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double));
  69. template <>
  70. void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float));
  71. template <>
  72. void gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
  73. template <>
  74. void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
  75. template <>
  76. void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
  77. template <>
  78. void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
  79. template<>
  80. void gemm_internal<at::Half, float>(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::Half, float));
  81. template<>
  82. void gemm_internal<at::BFloat16, float>(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float));
  // Epilogue selector for gemm_and_bias (its `activation` parameter, which
  // defaults to None). Values are spelled out explicitly: None = 0, RELU = 1,
  // GELU = 2.
  enum GEMMAndBiasActivationEpilogue {
    None = 0,
    RELU = 1,
    GELU = 2,
  };
  88. // NOTE: GELU activation is not supported prior to CUDA 11.4 and will
  89. // do nothing if passed in that case.
  90. template <typename Dtype, typename C_Dtype = Dtype>
  91. bool gemm_and_bias(
  92. bool transpose_mat1,
  93. bool transpose_mat2,
  94. int64_t m,
  95. int64_t n,
  96. int64_t k,
  97. at::opmath_type<Dtype> alpha_val,
  98. const Dtype* mat1_ptr,
  99. int64_t mat1_ld,
  100. const Dtype* mat2_ptr,
  101. int64_t mat2_ld,
  102. const Dtype* bias,
  103. C_Dtype* result_ptr,
  104. int64_t result_ld,
  105. GEMMAndBiasActivationEpilogue activation = GEMMAndBiasActivationEpilogue::None);
  // int8_gemm: matrix product of two int8_t operands into an int32_t result
  // buffer (element types as declared). Declared here, defined out of line.
  void int8_gemm(
      bool transpose_mat1, bool transpose_mat2,
      int64_t m, int64_t n, int64_t k,
      const int8_t* mat1_ptr, int64_t mat1_ld,
      const int8_t* mat2_ptr, int64_t mat2_ld,
      int32_t* result_ptr, int64_t result_ld);
  118. void scaled_gemm(
  119. char transa,
  120. char transb,
  121. int64_t m,
  122. int64_t n,
  123. int64_t k,
  124. const void* mat1_ptr,
  125. const void* mat1_scale_ptr,
  126. int64_t mat1_ld,
  127. ScalarType mat1_dtype,
  128. ScalarType mat1_scale_dtype,
  129. at::blas::ScalingType mat1_scaling_type,
  130. const void* mat2_ptr,
  131. const void* mat2_scale_ptr,
  132. int64_t mat2_ld,
  133. ScalarType mat2_dtype,
  134. ScalarType mat2_scale_dtype,
  135. at::blas::ScalingType mat2_scaling_type,
  136. const void* bias_ptr,
  137. ScalarType bias_dtype,
  138. void* result_ptr,
  139. const void* result_scale_ptr,
  140. int64_t result_ld,
  141. ScalarType result_dtype,
  142. bool use_fast_accum,
  143. const std::optional<Tensor>& alpha);
  144. #define CUDABLAS_BGEMM_ARGTYPES(Dtype) CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype)
  145. #define CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype) \
  146. char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \
  147. const Dtype *a, int64_t lda, int64_t stridea, \
  148. const Dtype *b, int64_t ldb, int64_t strideb, \
  149. at::opmath_type<Dtype> beta, C_Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches
  150. #define CUDABLAS_BGEMM_ARGS(Dtype) \
  151. transa, transb, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, beta, c, ldc, stridec, num_batches
  152. template <typename Dtype, typename C_Dtype = Dtype, typename std::enable_if<!CUDABLAS_GEMM_DTYPE_IS_FLOAT_TYPE_AND_C_DTYPE_IS_FLOAT, Dtype>::type* = nullptr>
  153. inline void bgemm(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) {
  154. static_assert(false&&sizeof(Dtype),"at::cuda::blas::bgemm: not implemented");
  155. }
  156. template <typename Dtype, typename C_Dtype, typename std::enable_if<CUDABLAS_GEMM_DTYPE_IS_FLOAT_TYPE_AND_C_DTYPE_IS_FLOAT, Dtype>::type* = nullptr>
  157. void bgemm(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype));
  158. template <>
  159. void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double));
  160. template <>
  161. void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float));
  162. template <>
  163. void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
  164. template <>
  165. void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
  166. template <>
  167. void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
  168. template <>
  169. void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
  170. template<>
  171. void bgemm<at::Half, float>(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::Half, float));
  172. template<>
  173. void bgemm<at::BFloat16, float>(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float));
  174. template <typename Dtype, typename C_Dtype = Dtype>
  175. inline void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) {
  176. static_assert(false&&sizeof(Dtype),"at::cuda::blas::bgemm_internal: not implemented");
  177. }
  178. template <>
  179. void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double));
  180. template <>
  181. void bgemm_internal<float>(CUDABLAS_BGEMM_ARGTYPES(float));
  182. template <>
  183. void bgemm_internal<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
  184. template <>
  185. void bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
  186. template <>
  187. void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
  188. template <>
  189. void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
  190. template<>
  191. void bgemm_internal<at::Half, float>(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::Half, float));
  192. template<>
  193. void bgemm_internal<at::BFloat16, float>(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float));
  194. #define CUDABLAS_TRSM_ARGTYPES(Dtype) \
  195. cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \
  196. cublasOperation_t trans, cublasDiagType_t diag, int m, int n, \
  197. const Dtype *alpha, const Dtype *A, int lda, Dtype *B, int ldb
  198. template <typename Dtype>
  199. inline void trsm(CUDABLAS_TRSM_ARGTYPES(Dtype)) {
  200. static_assert(false&&sizeof(Dtype), "at::cuda::blas::trsm: not implemented");
  201. }
  202. template <>
  203. TORCH_CUDA_CU_API void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float));
  204. template <>
  205. TORCH_CUDA_CU_API void trsm<double>(CUDABLAS_TRSM_ARGTYPES(double));
  206. template <>
  207. TORCH_CUDA_CU_API void trsm<c10::complex<float>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<float>));
  208. template <>
  209. TORCH_CUDA_CU_API void trsm<c10::complex<double>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<double>));
  210. #define CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype) \
  211. cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \
  212. cublasOperation_t trans, cublasDiagType_t diag, int m, int n, \
  213. const Dtype *alpha, Dtype *A[], int lda, Dtype *B[], int ldb, \
  214. int batchCount
  215. template <typename Dtype>
  216. inline void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype)) {
  217. static_assert(false&&sizeof(Dtype), "at::cuda::blas::trsmBatched: not implemented");
  218. }
  219. template <>
  220. TORCH_CUDA_CU_API void trsmBatched<float>(CUDABLAS_TRSM_BATCHED_ARGTYPES(float));
  221. template <>
  222. TORCH_CUDA_CU_API void trsmBatched<double>(CUDABLAS_TRSM_BATCHED_ARGTYPES(double));
  223. template <>
  224. TORCH_CUDA_CU_API void trsmBatched<c10::complex<float>>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<float>));
  225. template <>
  226. TORCH_CUDA_CU_API void trsmBatched<c10::complex<double>>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<double>));
  227. /* LEVEL 2 BLAS FUNCTIONS */
  228. #define CUDABLAS_GEMV_ARGTYPES(Dtype) \
  229. char trans, int64_t m, int64_t n, Dtype alpha, const Dtype *a, int64_t lda, \
  230. const Dtype *x, int64_t incx, Dtype beta, Dtype *y, int64_t incy
  231. template <typename Dtype>
  232. inline void gemv(CUDABLAS_GEMV_ARGTYPES(Dtype)) {
  233. static_assert(false&&sizeof(Dtype), "at::cuda::blas::gemv: not implemented");
  234. }
  235. template <>
  236. void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double));
  237. template <>
  238. void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float));
  239. template <>
  240. void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>));
  241. template <>
  242. void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>));
  243. template <>
  244. void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half));
  245. template <>
  246. void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16));
  247. /* LEVEL 1 BLAS FUNCTIONS */
  248. #define CUDABLAS_DOT_ARGTYPES(Dtype) \
  249. cublasHandle_t handle, int n, const Dtype *x, int incx, const Dtype *y, \
  250. int incy, Dtype *result
  251. template <typename Dtype>
  252. inline void dot(CUDABLAS_DOT_ARGTYPES(Dtype)) {
  253. static_assert(false&&sizeof(Dtype),"at::cuda::blas::dot: not implemented");
  254. }
  255. template <>
  256. void dot<double>(CUDABLAS_DOT_ARGTYPES(double));
  257. template <>
  258. void dot<float>(CUDABLAS_DOT_ARGTYPES(float));
  259. template <>
  260. void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half));
  261. template <>
  262. void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16));
  263. template <>
  264. void dot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
  265. template <>
  266. void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
  267. template <typename Dtype>
  268. inline void vdot(CUDABLAS_DOT_ARGTYPES(Dtype)) {
  269. static_assert(false&&sizeof(Dtype),"at::cuda::blas::vdot: not implemented");
  270. }
  271. template <>
  272. void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
  273. template <>
  274. void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
  /* BATCHED ROUTINES NAMED AFTER LAPACK getrs / geqrf / getrf / gels */

  // getrsBatched parameters: handle, transpose op, sizes, arrays of A and B
  // matrix pointers, pivot and info arrays, batch size.
  #define CUDABLAS_GETRS_ARGTYPES(Dtype) \
    cublasHandle_t handle, cublasOperation_t trans, \
    int n, int nrhs, \
    Dtype** dA_array, int lda, int* ipiv_array, \
    Dtype** dB_array, int ldb, \
    int* info_array, int batchsize
  // geqrfBatched parameters: handle, sizes, arrays of A and tau pointers,
  // info, batch size.
  #define CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype) \
    cublasHandle_t handle, int m, int n, \
    Dtype **A_array, int lda, \
    Dtype **tau_array, \
    int *info, int batchsize
  // getrfBatched parameters. Unlike the other three lists, this one takes no
  // cublasHandle_t.
  #define CUDABLAS_GETRF_ARGTYPES(Dtype) \
    int n, \
    Dtype** dA_array, int ldda, \
    int* ipiv_array, int* info_array, \
    int batchsize
  // gelsBatched parameters: handle, transpose op, sizes, arrays of A and C
  // pointers, host-side `info` and device-side `devInfoArray`, batch size.
  #define CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype) \
    cublasHandle_t handle, cublasOperation_t trans, \
    int m, int n, int nrhs, \
    Dtype** dA_array, int ldda, \
    Dtype** dC_array, int lddc, \
    int* info, int *devInfoArray, int batchSize
  288. template<class Dtype>
  289. void getrsBatched(CUDABLAS_GETRS_ARGTYPES(Dtype)) {
  290. static_assert(false&&sizeof(Dtype),"at::cuda::blas::getrsBatched: not implemented");
  291. }
  292. template<>
  293. TORCH_CUDA_CU_API void getrsBatched<float>(CUDABLAS_GETRS_ARGTYPES(float));
  294. template<>
  295. TORCH_CUDA_CU_API void getrsBatched<double>(CUDABLAS_GETRS_ARGTYPES(double));
  296. template<>
  297. TORCH_CUDA_CU_API void getrsBatched<c10::complex<float>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<float>));
  298. template<>
  299. TORCH_CUDA_CU_API void getrsBatched<c10::complex<double>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<double>));
  300. template <class Dtype>
  301. void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) {
  302. static_assert(false&&sizeof(Dtype), "at::cuda::blas::geqrfBatched: not implemented");
  303. }
  304. template <>
  305. TORCH_CUDA_CU_API void geqrfBatched<float>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float));
  306. template <>
  307. TORCH_CUDA_CU_API void geqrfBatched<double>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(double));
  308. template <>
  309. TORCH_CUDA_CU_API void geqrfBatched<c10::complex<double>>(
  310. CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<double>));
  311. template <>
  312. TORCH_CUDA_CU_API void geqrfBatched<c10::complex<float>>(
  313. CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<float>));
  314. template<class Dtype>
  315. void getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) {
  316. static_assert(false&&sizeof(Dtype), "at::cuda::blas::getrfBatched: not implemented");
  317. }
  318. template<>
  319. TORCH_CUDA_CU_API void getrfBatched<float>(CUDABLAS_GETRF_ARGTYPES(float));
  320. template<>
  321. TORCH_CUDA_CU_API void getrfBatched<double>(CUDABLAS_GETRF_ARGTYPES(double));
  322. template<>
  323. TORCH_CUDA_CU_API void getrfBatched<c10::complex<double>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<double>));
  324. template<>
  325. TORCH_CUDA_CU_API void getrfBatched<c10::complex<float>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<float>));
  326. template <class Dtype>
  327. void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)) {
  328. static_assert(false&&sizeof(Dtype), "at::cuda::blas::gelsBatched: not implemented");
  329. }
  330. template<>
  331. TORCH_CUDA_CU_API void gelsBatched<double>(CUDABLAS_GELS_BATCHED_ARGTYPES(double));
  332. template<>
  333. TORCH_CUDA_CU_API void gelsBatched<float>(CUDABLAS_GELS_BATCHED_ARGTYPES(float));
  334. template<>
  335. TORCH_CUDA_CU_API void gelsBatched<c10::complex<double>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<double>));
  336. template<>
  337. TORCH_CUDA_CU_API void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>));
  338. } // namespace at::cuda::blas
  339. #else
  340. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  341. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)