CPUBlas.h 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/OpMathType.h>
  4. #include <ATen/native/DispatchStub.h>
  5. #include <ATen/native/TransposeType.h>
  6. #include <c10/util/complex.h>
  7. #include <c10/core/ScalarType.h>
  8. #include <c10/core/Scalar.h>
  9. namespace at::native::cpublas {
  // Helper namespace used by the templates further down in this header.
  10. namespace internal {
  // Declaration only; the definition is not visible here. Receives the two
  // transpose flags, the GEMM extents (m, n, k) and pointers to the three
  // leading dimensions. The gemm() template in this header calls it right
  // before handing the arguments to gemm_stub.
  // NOTE(review): the pointer parameters suggest lda/ldb/ldc may be rewritten
  // in place -- confirm against the definition.
  11. void normalize_last_dims(
  12. TransposeType transa, TransposeType transb,
  13. int64_t m, int64_t n, int64_t k,
  14. int64_t *lda, int64_t *ldb, int64_t *ldc);
  15. } // namespace internal
  // Function-pointer type of the type-erased GEMM kernel: the element type is
  // passed at run time as an at::ScalarType, alpha/beta as Scalar, and the
  // matrices as untyped pointers together with their leading dimensions.
  16. using gemm_fn = void(*)(
  17. at::ScalarType type,
  18. TransposeType transa, TransposeType transb,
  19. int64_t m, int64_t n, int64_t k,
  20. const Scalar& alpha,
  21. const void *a, int64_t lda,
  22. const void *b, int64_t ldb,
  23. const Scalar& beta,
  24. void *c, int64_t ldc);
  // Declares the dispatch stub gemm_stub for kernels of type gemm_fn
  // (DECLARE_DISPATCH comes from ATen/native/DispatchStub.h).
  25. DECLARE_DISPATCH(gemm_fn, gemm_stub)
  // Function-pointer type with exactly the same parameter list as gemm_fn,
  // dispatched through its own stub.
  // NOTE(review): the name suggests the result is not converted back to the
  // input element type (compare the gemm overloads below that take
  // BFloat16/Half inputs with a float* output) -- confirm in the kernels.
  26. using gemm_no_downcast_fn = void(*)(
  27. at::ScalarType type,
  28. TransposeType transa, TransposeType transb,
  29. int64_t m, int64_t n, int64_t k,
  30. const Scalar& alpha,
  31. const void *a, int64_t lda,
  32. const void *b, int64_t ldb,
  33. const Scalar& beta,
  34. void *c, int64_t ldc);
  // Declares the dispatch stub gemm_no_downcast_stub for this kernel type.
  35. DECLARE_DISPATCH(gemm_no_downcast_fn, gemm_no_downcast_stub)
  36. template <typename scalar_t>
  37. void gemm(
  38. TransposeType transa, TransposeType transb,
  39. int64_t m, int64_t n, int64_t k,
  40. at::opmath_type<scalar_t> alpha,
  41. const scalar_t *a, int64_t lda,
  42. const scalar_t *b, int64_t ldb,
  43. at::opmath_type<scalar_t> beta,
  44. scalar_t *c, int64_t ldc) {
  45. internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
  46. gemm_stub(
  47. kCPU, c10::CppTypeToScalarType<scalar_t>::value,
  48. transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  49. }
  // Non-template gemm overload for double operands (declaration only; the
  // definition is not visible here). Parameter order mirrors the template
  // above: transpose flags, extents, alpha, a/lda, b/ldb, beta, c/ldc.
  50. void gemm(
  51. TransposeType transa, TransposeType transb,
  52. int64_t m, int64_t n, int64_t k,
  53. double alpha,
  54. const double *a, int64_t lda,
  55. const double *b, int64_t ldb,
  56. double beta,
  57. double *c, int64_t ldc);
  // Non-template gemm overload for float operands (declaration only).
  58. void gemm(
  59. TransposeType transa, TransposeType transb,
  60. int64_t m, int64_t n, int64_t k,
  61. float alpha,
  62. const float *a, int64_t lda,
  63. const float *b, int64_t ldb,
  64. float beta,
  65. float *c, int64_t ldc);
  // BFloat16 operands and BFloat16 output; alpha and beta are float
  // (declaration only).
  66. void gemm(
  67. TransposeType transa, TransposeType transb,
  68. int64_t m, int64_t n, int64_t k,
  69. float alpha,
  70. const at::BFloat16 *a, int64_t lda,
  71. const at::BFloat16 *b, int64_t ldb,
  72. float beta,
  73. at::BFloat16 *c, int64_t ldc);
  // BFloat16 operands with a float output buffer. The gemm template cannot
  // match this call shape because a/b and c have different element types.
  // NOTE(review): presumably backed by gemm_no_downcast_stub -- confirm in
  // the definition.
  74. void gemm(
  75. TransposeType transa, TransposeType transb,
  76. int64_t m, int64_t n, int64_t k,
  77. const float alpha,
  78. const at::BFloat16 *a, int64_t lda,
  79. const at::BFloat16 *b, int64_t ldb,
  80. const float beta,
  81. float *c, int64_t ldc);
  // Half operands and Half output; alpha and beta are float (declaration
  // only).
  82. void gemm(
  83. TransposeType transa, TransposeType transb,
  84. int64_t m, int64_t n, int64_t k,
  85. float alpha,
  86. const at::Half *a, int64_t lda,
  87. const at::Half *b, int64_t ldb,
  88. float beta,
  89. at::Half *c, int64_t ldc);
  // Half operands with a float output buffer; same shape as the
  // BFloat16 -> float overload above.
  90. void gemm(
  91. TransposeType transa, TransposeType transb,
  92. int64_t m, int64_t n, int64_t k,
  93. const float alpha,
  94. const at::Half *a, int64_t lda,
  95. const at::Half *b, int64_t ldb,
  96. const float beta,
  97. float *c, int64_t ldc);
  // Non-template gemm overload for c10::complex<double> operands; alpha and
  // beta are complex as well (declaration only).
  98. void gemm(
  99. TransposeType transa, TransposeType transb,
  100. int64_t m, int64_t n, int64_t k,
  101. c10::complex<double> alpha,
  102. const c10::complex<double> *a, int64_t lda,
  103. const c10::complex<double> *b, int64_t ldb,
  104. c10::complex<double> beta,
  105. c10::complex<double> *c, int64_t ldc);
  // Non-template gemm overload for c10::complex<float> operands
  // (declaration only).
  106. void gemm(
  107. TransposeType transa, TransposeType transb,
  108. int64_t m, int64_t n, int64_t k,
  109. c10::complex<float> alpha,
  110. const c10::complex<float> *a, int64_t lda,
  111. const c10::complex<float> *b, int64_t ldb,
  112. c10::complex<float> beta,
  113. c10::complex<float> *c, int64_t ldc);
  // Non-template gemm overload for int64_t operands; alpha and beta are
  // int64_t as well (declaration only; the definition is not visible here).
  114. void gemm(
  115. TransposeType transa, TransposeType transb,
  116. int64_t m, int64_t n, int64_t k,
  117. int64_t alpha,
  118. const int64_t *a, int64_t lda,
  119. const int64_t *b, int64_t ldb,
  120. int64_t beta,
  121. int64_t *c, int64_t ldc);
  // Batched GEMM over arrays of matrix pointers (template declaration only;
  // the definition is not visible here). a, b and c are arrays of per-batch
  // pointers, and a single lda/ldb/ldc is passed for each array. Unlike the
  // gemm template, alpha and beta are plain scalar_t rather than
  // at::opmath_type<scalar_t>.
  // NOTE(review): presumably each array holds batch_size entries -- confirm.
  122. template <typename scalar_t>
  123. void gemm_batched(
  124. TransposeType transa, TransposeType transb,
  125. int64_t batch_size, int64_t m, int64_t n, int64_t k,
  126. scalar_t alpha,
  127. const scalar_t * const *a, int64_t lda,
  128. const scalar_t * const *b, int64_t ldb,
  129. const scalar_t beta,
  130. scalar_t * const *c, int64_t ldc);
  // Batched GEMM addressed by a base pointer plus a per-operand batch stride
  // instead of an array of pointers (template declaration only).
  // NOTE(review): batch_stride_a/b/c are presumably the element offsets
  // between consecutive matrices of a batch -- confirm in the definition.
  131. template <typename scalar_t>
  132. void gemm_batched_with_stride(
  133. TransposeType transa, TransposeType transb,
  134. int64_t batch_size, int64_t m, int64_t n, int64_t k,
  135. scalar_t alpha,
  136. const scalar_t *a, int64_t lda, int64_t batch_stride_a,
  137. const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
  138. scalar_t beta,
  139. scalar_t *c, int64_t ldc, int64_t batch_stride_c);
  // Function-pointer type of the type-erased axpy kernel: runtime element
  // type, element count n, scalar a, input x with increment incx, and
  // output y with increment incy.
  140. using axpy_fn = void(*)(at::ScalarType type, int64_t n, const Scalar& a, const void *x, int64_t incx, void *y, int64_t incy);
  // Declares the dispatch stub axpy_stub for kernels of type axpy_fn.
  141. DECLARE_DISPATCH(axpy_fn, axpy_stub)
  142. template<typename scalar_t>
  143. void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy){
  144. if(n == 1)
  145. {
  146. incx = 1;
  147. incy = 1;
  148. }
  149. axpy_stub(
  150. kCPU, c10::CppTypeToScalarType<scalar_t>::value,
  151. n, a, x, incx, y, incy);
  152. }
  // Non-template axpy overloads for double, float, c10::complex<double> and
  // c10::complex<float>. Declarations only; the definitions are not visible
  // here. Parameters match the axpy template: n, a, x/incx, y/incy.
  153. void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy);
  154. void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy);
  155. void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
  156. void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
  // Function-pointer type of the type-erased copy kernel: runtime element
  // type, element count n, input x with increment incx, and output y with
  // increment incy.
  157. using copy_fn = void(*)(at::ScalarType type, int64_t n, const void *x, int64_t incx, void *y, int64_t incy);
  // Declares the dispatch stub copy_stub for kernels of type copy_fn.
  158. DECLARE_DISPATCH(copy_fn, copy_stub)
  159. template<typename scalar_t>
  160. void copy(int64_t n, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) {
  161. if(n == 1)
  162. {
  163. incx = 1;
  164. incy = 1;
  165. }
  166. copy_stub(
  167. kCPU, c10::CppTypeToScalarType<scalar_t>::value,
  168. n, x, incx, y, incy);
  169. }
  // Non-template copy overloads for double, float, c10::complex<double> and
  // c10::complex<float>. Declarations only; the definitions are not visible
  // here. Parameters match the copy template: n, x/incx, y/incy.
  170. void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy);
  171. void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy);
  172. void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
  173. void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
  174. // Batch-reduce GEMM
  175. // Operates by the following formula:
  176. // C = SUM(A[i] x B[i]) for i = 0 to batch size; the previous contents of C are added to the sum only if add_C is true
  177. // A Base pointer to a tensor A.
  178. // B Base pointer to a tensor B.
  179. // C Pointer to a tensor C (accumulation buffer).
  180. // Note only batch size 1 is used currently
  181. // Define macros for available brgemm APIs
  182. // so that callers can determine which APIs are available
  // Each macro is defined without a value, so it is only meaningful in
  // #ifdef / defined() tests. Each one corresponds to one of the brgemm
  // overloads declared right after this list (operand types -> output type).
  183. #define CPUBLAS_BRGEMM_F16F16F32 // half * half -> float
  184. #define CPUBLAS_BRGEMM_BF16BF16F32 // bfloat16 * bfloat16 -> float
  185. #define CPUBLAS_BRGEMM_F32F32F32 // float * float -> float
  186. #define CPUBLAS_BRGEMM_U8U8I32 // unsigned char * unsigned char -> int32
  187. #define CPUBLAS_BRGEMM_U8I8I32 // unsigned char * signed char -> int32
  188. #define CPUBLAS_BRGEMM_I8I8I32 // signed char * signed char -> int32
  // brgemm overload: at::Half A and B, float C (CPUBLAS_BRGEMM_F16F16F32).
  // Takes the extents M, N, K, one leading dimension per operand, and add_C,
  // which selects whether the product is accumulated onto the existing C.
  // is_vnni defaults to true.
  // NOTE(review): is_vnni presumably states that B is in a VNNI-packed
  // layout -- confirm in the implementation.
  189. TORCH_API void brgemm(
  190. int64_t M,
  191. int64_t N,
  192. int64_t K,
  193. int64_t ld_a,
  194. int64_t ld_b,
  195. int64_t ld_c,
  196. const bool add_C,
  197. const at::Half* A,
  198. const at::Half* B,
  199. float* C,
  200. bool is_vnni = true);
  // brgemm overload: at::BFloat16 A and B, float C
  // (CPUBLAS_BRGEMM_BF16BF16F32). Same parameter list as the other brgemm
  // overloads; is_vnni defaults to true.
  201. TORCH_API void brgemm(
  202. int64_t M,
  203. int64_t N,
  204. int64_t K,
  205. int64_t ld_a,
  206. int64_t ld_b,
  207. int64_t ld_c,
  208. const bool add_C,
  209. const at::BFloat16* A,
  210. const at::BFloat16* B,
  211. float* C,
  212. bool is_vnni = true);
  // brgemm overload: float A and B, float C (CPUBLAS_BRGEMM_F32F32F32).
  // Same parameter list as the other brgemm overloads, but this is the only
  // one whose is_vnni defaults to false.
  213. TORCH_API void brgemm(
  214. int64_t M,
  215. int64_t N,
  216. int64_t K,
  217. int64_t ld_a,
  218. int64_t ld_b,
  219. int64_t ld_c,
  220. const bool add_C,
  221. const float* A,
  222. const float* B,
  223. float* C,
  224. bool is_vnni = false);
  // brgemm overload: unsigned char A and B, int32_t C
  // (CPUBLAS_BRGEMM_U8U8I32). Same parameter list as the other brgemm
  // overloads; is_vnni defaults to true.
  225. TORCH_API void brgemm(
  226. int64_t M,
  227. int64_t N,
  228. int64_t K,
  229. int64_t ld_a,
  230. int64_t ld_b,
  231. int64_t ld_c,
  232. const bool add_C,
  233. const unsigned char* A,
  234. const unsigned char* B,
  235. int32_t* C,
  236. bool is_vnni = true);
  // brgemm overload: unsigned char A, signed char B, int32_t C
  // (CPUBLAS_BRGEMM_U8I8I32). Same parameter list as the other brgemm
  // overloads; is_vnni defaults to true.
  237. TORCH_API void brgemm(
  238. int64_t M,
  239. int64_t N,
  240. int64_t K,
  241. int64_t ld_a,
  242. int64_t ld_b,
  243. int64_t ld_c,
  244. const bool add_C,
  245. const unsigned char* A,
  246. const signed char* B,
  247. int32_t* C,
  248. bool is_vnni = true);
  // brgemm overload: signed char A and B, int32_t C
  // (CPUBLAS_BRGEMM_I8I8I32). Same parameter list as the other brgemm
  // overloads; is_vnni defaults to true.
  249. TORCH_API void brgemm(
  250. int64_t M,
  251. int64_t N,
  252. int64_t K,
  253. int64_t ld_a,
  254. int64_t ld_b,
  255. int64_t ld_c,
  256. const bool add_C,
  257. const signed char* A,
  258. const signed char* B,
  259. int32_t* C,
  260. bool is_vnni = true);
  261. // Release brgemm hardware context
  // is_vnni defaults to true, like most brgemm overloads.
  // NOTE(review): presumably callers pass the same is_vnni value they used
  // for the preceding brgemm calls -- confirm in the implementation.
  262. TORCH_API void brgemm_release(bool is_vnni = true);
  263. // Pack B matrix to get better performance if needed
  // Reads the matrix from the untyped buffer `in` and writes the packed result
  // to `out`; dt_in and dt_out carry the element types of the two buffers.
  // NOTE(review): K and N are presumably the row/column extents of B, and
  // ld_in/ld_out the leading dimensions of the source and packed buffers --
  // confirm in the implementation.
  264. TORCH_API void pack(
  265. int64_t K,
  266. int64_t N,
  267. int64_t ld_in,
  268. int64_t ld_out,
  269. ScalarType dt_in,
  270. ScalarType dt_out,
  271. const void* in,
  272. void* out);
  273. // Whether pack is supported in the platform.
  // The answer is requested per input element type (dt_in).
  // NOTE(review): presumably meant to be checked before calling pack() --
  // confirm at the call sites.
  274. TORCH_API bool could_pack(ScalarType dt_in);
  275. } // namespace at::native::cpublas
  276. #else
  277. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  278. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)