BlasBackend.h 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
#include <c10/util/Exception.h>

#include <cstdint>
#include <ostream>
#include <string>
  6. namespace at {
  7. enum class BlasBackend : int8_t { Default, Cublas, Cublaslt, Ck };
  8. inline std::string BlasBackendToString(at::BlasBackend backend) {
  9. switch (backend) {
  10. case BlasBackend::Default:
  11. return "at::BlasBackend::Default";
  12. case BlasBackend::Cublas:
  13. return "at::BlasBackend::Cublas";
  14. case BlasBackend::Cublaslt:
  15. return "at::BlasBackend::Cublaslt";
  16. case BlasBackend::Ck:
  17. return "at::BlasBackend::Ck";
  18. default:
  19. TORCH_CHECK(false, "Unknown blas backend");
  20. }
  21. }
  22. inline std::ostream& operator<<(std::ostream& stream, at::BlasBackend backend) {
  23. return stream << BlasBackendToString(backend);
  24. }
  25. namespace blas {
  26. enum class ScalingType : std::uint8_t {
  27. TensorWise, // fp32 scales
  28. RowWise, // fp32 scales
  29. BlockWise1x16, // fp8_e4m3fn scales
  30. BlockWise1x32, // fp8_e8m0fnu scales
  31. BlockWise1x128, // fp32 scales
  32. BlockWise128x128, // fp32 scales
  33. };
  34. enum class SwizzleType : std::uint8_t { NO_SWIZZLE = 0, SWIZZLE_32_4_4 = 1 };
  35. } // namespace blas
  36. } // namespace at
  37. #else
  38. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  39. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)