| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319 |
- #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
- #pragma once
- #include <ATen/OpMathType.h>
- #include <ATen/native/DispatchStub.h>
- #include <ATen/native/TransposeType.h>
- #include <c10/util/complex.h>
- #include <c10/core/ScalarType.h>
- #include <c10/core/Scalar.h>
- namespace at::native::cpublas {
- namespace internal {
- void normalize_last_dims(
- TransposeType transa, TransposeType transb,
- int64_t m, int64_t n, int64_t k,
- int64_t *lda, int64_t *ldb, int64_t *ldc);
- } // namespace internal
- using gemm_fn = void(*)(
- at::ScalarType type,
- TransposeType transa, TransposeType transb,
- int64_t m, int64_t n, int64_t k,
- const Scalar& alpha,
- const void *a, int64_t lda,
- const void *b, int64_t ldb,
- const Scalar& beta,
- void *c, int64_t ldc);
- DECLARE_DISPATCH(gemm_fn, gemm_stub)
- using gemm_no_downcast_fn = void(*)(
- at::ScalarType type,
- TransposeType transa, TransposeType transb,
- int64_t m, int64_t n, int64_t k,
- const Scalar& alpha,
- const void *a, int64_t lda,
- const void *b, int64_t ldb,
- const Scalar& beta,
- void *c, int64_t ldc);
- DECLARE_DISPATCH(gemm_no_downcast_fn, gemm_no_downcast_stub)
- template <typename scalar_t>
- void gemm(
- TransposeType transa, TransposeType transb,
- int64_t m, int64_t n, int64_t k,
- at::opmath_type<scalar_t> alpha,
- const scalar_t *a, int64_t lda,
- const scalar_t *b, int64_t ldb,
- at::opmath_type<scalar_t> beta,
- scalar_t *c, int64_t ldc) {
- internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
- gemm_stub(
- kCPU, c10::CppTypeToScalarType<scalar_t>::value,
- transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
- }
- void gemm(
- TransposeType transa, TransposeType transb,
- int64_t m, int64_t n, int64_t k,
- double alpha,
- const double *a, int64_t lda,
- const double *b, int64_t ldb,
- double beta,
- double *c, int64_t ldc);
- void gemm(
- TransposeType transa, TransposeType transb,
- int64_t m, int64_t n, int64_t k,
- float alpha,
- const float *a, int64_t lda,
- const float *b, int64_t ldb,
- float beta,
- float *c, int64_t ldc);
- void gemm(
- TransposeType transa, TransposeType transb,
- int64_t m, int64_t n, int64_t k,
- float alpha,
- const at::BFloat16 *a, int64_t lda,
- const at::BFloat16 *b, int64_t ldb,
- float beta,
- at::BFloat16 *c, int64_t ldc);
- void gemm(
- TransposeType transa, TransposeType transb,
- int64_t m, int64_t n, int64_t k,
- const float alpha,
- const at::BFloat16 *a, int64_t lda,
- const at::BFloat16 *b, int64_t ldb,
- const float beta,
- float *c, int64_t ldc);
- void gemm(
- TransposeType transa, TransposeType transb,
- int64_t m, int64_t n, int64_t k,
- float alpha,
- const at::Half *a, int64_t lda,
- const at::Half *b, int64_t ldb,
- float beta,
- at::Half *c, int64_t ldc);
- void gemm(
- TransposeType transa, TransposeType transb,
- int64_t m, int64_t n, int64_t k,
- const float alpha,
- const at::Half *a, int64_t lda,
- const at::Half *b, int64_t ldb,
- const float beta,
- float *c, int64_t ldc);
- void gemm(
- TransposeType transa, TransposeType transb,
- int64_t m, int64_t n, int64_t k,
- c10::complex<double> alpha,
- const c10::complex<double> *a, int64_t lda,
- const c10::complex<double> *b, int64_t ldb,
- c10::complex<double> beta,
- c10::complex<double> *c, int64_t ldc);
- void gemm(
- TransposeType transa, TransposeType transb,
- int64_t m, int64_t n, int64_t k,
- c10::complex<float> alpha,
- const c10::complex<float> *a, int64_t lda,
- const c10::complex<float> *b, int64_t ldb,
- c10::complex<float> beta,
- c10::complex<float> *c, int64_t ldc);
- void gemm(
- TransposeType transa, TransposeType transb,
- int64_t m, int64_t n, int64_t k,
- int64_t alpha,
- const int64_t *a, int64_t lda,
- const int64_t *b, int64_t ldb,
- int64_t beta,
- int64_t *c, int64_t ldc);
- template <typename scalar_t>
- void gemm_batched(
- TransposeType transa, TransposeType transb,
- int64_t batch_size, int64_t m, int64_t n, int64_t k,
- scalar_t alpha,
- const scalar_t * const *a, int64_t lda,
- const scalar_t * const *b, int64_t ldb,
- const scalar_t beta,
- scalar_t * const *c, int64_t ldc);
- template <typename scalar_t>
- void gemm_batched_with_stride(
- TransposeType transa, TransposeType transb,
- int64_t batch_size, int64_t m, int64_t n, int64_t k,
- scalar_t alpha,
- const scalar_t *a, int64_t lda, int64_t batch_stride_a,
- const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
- scalar_t beta,
- scalar_t *c, int64_t ldc, int64_t batch_stride_c);
- using axpy_fn = void(*)(at::ScalarType type, int64_t n, const Scalar& a, const void *x, int64_t incx, void *y, int64_t incy);
- DECLARE_DISPATCH(axpy_fn, axpy_stub)
- template<typename scalar_t>
- void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy){
- if(n == 1)
- {
- incx = 1;
- incy = 1;
- }
- axpy_stub(
- kCPU, c10::CppTypeToScalarType<scalar_t>::value,
- n, a, x, incx, y, incy);
- }
- void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy);
- void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy);
- void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
- void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
- using copy_fn = void(*)(at::ScalarType type, int64_t n, const void *x, int64_t incx, void *y, int64_t incy);
- DECLARE_DISPATCH(copy_fn, copy_stub)
- template<typename scalar_t>
- void copy(int64_t n, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) {
- if(n == 1)
- {
- incx = 1;
- incy = 1;
- }
- copy_stub(
- kCPU, c10::CppTypeToScalarType<scalar_t>::value,
- n, x, incx, y, incy);
- }
- void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy);
- void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy);
- void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
- void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
- // Batch-reduce GEMM
- // Operates by the following formula:
- // C = SUM(A[i] x B[i]) + C if add_C is true, i = 0 to batch size
- // A Base pointer to a tensor A.
- // B Base pointer to a tensor B.
- // C Pointer to a tensor C (accumulation buffer).
- // Note only batch size 1 is used currently
- // Define macros for available brgemm APIs
- // so that callers can determine which APIs are available
- #define CPUBLAS_BRGEMM_F16F16F32 // half * half -> float
- #define CPUBLAS_BRGEMM_BF16BF16F32 // bfloat16 * bfloat16 -> float
- #define CPUBLAS_BRGEMM_F32F32F32 // float * float -> float
- #define CPUBLAS_BRGEMM_U8U8I32 // unsigned char * unsigned char -> int32
- #define CPUBLAS_BRGEMM_U8I8I32 // unsigned char * signed char -> int32
- #define CPUBLAS_BRGEMM_I8I8I32 // signed char * signed char -> int32
- TORCH_API void brgemm(
- int64_t M,
- int64_t N,
- int64_t K,
- int64_t ld_a,
- int64_t ld_b,
- int64_t ld_c,
- const bool add_C,
- const at::Half* A,
- const at::Half* B,
- float* C,
- bool is_vnni = true);
- TORCH_API void brgemm(
- int64_t M,
- int64_t N,
- int64_t K,
- int64_t ld_a,
- int64_t ld_b,
- int64_t ld_c,
- const bool add_C,
- const at::BFloat16* A,
- const at::BFloat16* B,
- float* C,
- bool is_vnni = true);
- TORCH_API void brgemm(
- int64_t M,
- int64_t N,
- int64_t K,
- int64_t ld_a,
- int64_t ld_b,
- int64_t ld_c,
- const bool add_C,
- const float* A,
- const float* B,
- float* C,
- bool is_vnni = false);
- TORCH_API void brgemm(
- int64_t M,
- int64_t N,
- int64_t K,
- int64_t ld_a,
- int64_t ld_b,
- int64_t ld_c,
- const bool add_C,
- const unsigned char* A,
- const unsigned char* B,
- int32_t* C,
- bool is_vnni = true);
- TORCH_API void brgemm(
- int64_t M,
- int64_t N,
- int64_t K,
- int64_t ld_a,
- int64_t ld_b,
- int64_t ld_c,
- const bool add_C,
- const unsigned char* A,
- const signed char* B,
- int32_t* C,
- bool is_vnni = true);
- TORCH_API void brgemm(
- int64_t M,
- int64_t N,
- int64_t K,
- int64_t ld_a,
- int64_t ld_b,
- int64_t ld_c,
- const bool add_C,
- const signed char* A,
- const signed char* B,
- int32_t* C,
- bool is_vnni = true);
- // Release brgemm hardware context
- TORCH_API void brgemm_release(bool is_vnni = true);
- // Pack B matrix to get better performance if needed
- TORCH_API void pack(
- int64_t K,
- int64_t N,
- int64_t ld_in,
- int64_t ld_out,
- ScalarType dt_in,
- ScalarType dt_out,
- const void* in,
- void* out);
- // Whether pack is supported in the platform.
- TORCH_API bool could_pack(ScalarType dt_in);
- } // namespace at::native::cpublas
- #else
- #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
- #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|