CUDAScaledBlas.h 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #include <cstdint>
  3. #include <c10/util/typeid.h>
  4. #include <c10/util/Exception.h>
  5. #include <c10/util/SmallVector.h>
  6. #include <c10/core/Scalar.h>
  7. #include <c10/core/ScalarType.h>
  8. #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
  9. #include <ATen/core/Tensor.h>
  10. #include <ATen/core/NamedTensor.h>
  11. #include <ATen/Dispatch.h>
  12. #include <ATen/ExpandUtils.h>
  13. #include <ATen/OpMathType.h>
  14. #include <ATen/TensorUtils.h>
  15. #include <ATen/cuda/CUDABlas.h>
  16. #include <ATen/cuda/tunable/Tunable.h>
  17. #include <ATen/cuda/tunable/TunableGemm.h>
  18. #include <ATen/native/Resize.h>
  19. #include <c10/util/MaybeOwned.h>
  20. #include <ATen/native/GroupedMMUtils.h>
  21. #include <ATen/native/cuda/RowwiseScaledMM.h>
  22. #include <ATen/native/cuda/ScaledGroupMM.h>
  23. #include <ATen/native/cuda/GroupMM.h>
  24. #include <ATen/ceil_div.h>
  25. #ifdef USE_MSLK
  26. #include <mslk/gemm/gemm_torch.h>
  27. #endif
  28. #ifndef AT_PER_OPERATOR_HEADERS
  29. #include <ATen/Functions.h>
  30. #include <ATen/NativeFunctions.h>
  31. #else
  32. #include <ATen/ops/_addmm_activation_native.h>
  33. #include <ATen/ops/_efficientzerotensor.h>
  34. #include <ATen/ops/_scaled_mm_native.h>
  35. #include <ATen/ops/_unsafe_view_native.h>
  36. #include <ATen/ops/abs.h>
  37. #include <ATen/ops/addmm_native.h>
  38. #include <ATen/ops/addmv_native.h>
  39. #include <ATen/ops/baddbmm_native.h>
  40. #include <ATen/ops/bmm_native.h>
  41. #include <ATen/ops/copy_native.h>
  42. #include <ATen/ops/dot_native.h>
  43. #include <ATen/ops/empty.h>
  44. #include <ATen/ops/empty_strided.h>
  45. #include <ATen/ops/gelu.h>
  46. #include <ATen/ops/max.h>
  47. #include <ATen/ops/mm_native.h>
  48. #include <ATen/ops/mul.h>
  49. #include <ATen/ops/relu.h>
  50. #include <ATen/ops/ones.h>
  51. #include <ATen/ops/scalar_tensor_native.h>
  52. #include <ATen/ops/vdot_native.h>
  53. #endif
  54. using at::blas::ScalingType;
  55. using at::blas::SwizzleType;
  56. namespace at::cuda::scaled {
  57. static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=false) {
  58. #ifdef USE_ROCM
  59. static const std::vector<std::string> archs = {
  60. "gfx942",
  61. #if ROCM_VERSION >= 60300
  62. "gfx1200", "gfx1201",
  63. #endif
  64. #if ROCM_VERSION >= 60500
  65. "gfx950"
  66. #endif
  67. };
  68. return at::detail::getCUDAHooks().isGPUArch(archs);
  69. #else
  70. auto dprops = at::cuda::getCurrentDeviceProperties();
  71. if (sm90_only || sm100_only) {
  72. return (sm90_only && dprops->major == 9) || (sm100_only && dprops->major == 10);
  73. } else {
  74. return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
  75. }
  76. #endif
  77. }
  78. #ifdef USE_ROCM
  79. static bool _scaled_mm_is_fnuz() {
  80. return at::detail::getCUDAHooks().isGPUArch({"gfx942"});
  81. }
  82. #endif
  /**
   * Identifiers for the concrete scaled-GEMM implementations that exist.
   *
   * Every enumerator has its numeric value written out explicitly, with NONE
   * as 0 and the rest numbered consecutively from 1 to 9.
   * NOTE(review): the names look like "<scaling of one operand>_<scaling of
   * the other>" pairs -- confirm the operand order against the code that
   * selects an implementation.
   */
  enum class ScaledGemmImplementation {
    NONE                     = 0,
    TENSORWISE_TENSORWISE    = 1,
    ROWWISE_ROWWISE          = 2,
    BLOCK_128x128_1x128      = 3,
    BLOCK_1x128_128x128      = 4,
    BLOCK_1x128_1x128        = 5,
    MXFP8_MXFP8              = 6,
    NVFP4_NVFP4              = 7,
    NVFP4_NVFP4_SINGLE_SCALE = 8,
    MXFP4_MXFP4              = 9,
  };
  /**
   * Convert a sequence of integer codes (enum values passed in from Python)
   * back into a vector of the strictly-typed enum EnumType.
   *
   * @tparam EnumType  destination enum type.
   * @tparam ArrayType any range that offers size() and can be iterated with a
   *                   range-based for loop.
   * @param v          the integer codes; only read, never modified. Taken by
   *                   const reference so that const containers and
   *                   temporaries are accepted as well as mutable lvalues.
   * @return           a vector with one EnumType per input element, in the
   *                   same order.
   *
   * Each element is converted with a plain static_cast: no check is made that
   * the integer corresponds to a declared enumerator.
   */
  template <class EnumType, class ArrayType>
  std::vector<EnumType> convert_int_to_enum(const ArrayType& v) {
    std::vector<EnumType> converted;
    // Single allocation up front; the output has exactly v.size() entries.
    converted.reserve(v.size());
    for (const auto& raw : v) {
      converted.push_back(static_cast<EnumType>(raw));
    }
    return converted;
  }
  111. bool check_tensorwise_recipe(c10::ScalarType,
  112. std::vector<ScalingType>&,
  113. ArrayRef<Tensor>&,
  114. c10::ScalarType,
  115. std::vector<ScalingType>&,
  116. ArrayRef<Tensor>&);
  117. bool check_rowwise_recipe(c10::ScalarType,
  118. std::vector<ScalingType>&,
  119. ArrayRef<Tensor>&,
  120. c10::ScalarType,
  121. std::vector<ScalingType>&,
  122. ArrayRef<Tensor>&);
  123. bool check_nvfp4_recipe(c10::ScalarType,
  124. std::vector<ScalingType>&,
  125. ArrayRef<Tensor>&,
  126. c10::ScalarType,
  127. std::vector<ScalingType>&,
  128. ArrayRef<Tensor>&);
  129. bool check_nvfp4_recipe_single_scale
  130. (c10::ScalarType,
  131. std::vector<ScalingType>&,
  132. ArrayRef<Tensor>&,
  133. c10::ScalarType,
  134. std::vector<ScalingType>&,
  135. ArrayRef<Tensor>&);
  136. bool check_deepseek_recipe(ScalingType,
  137. ScalingType,
  138. c10::ScalarType,
  139. std::vector<ScalingType>&,
  140. ArrayRef<Tensor>&,
  141. c10::ScalarType,
  142. std::vector<ScalingType>&,
  143. ArrayRef<Tensor>&);
  144. bool check_mxfp8_recipe(c10::ScalarType,
  145. std::vector<ScalingType>&,
  146. ArrayRef<Tensor>&,
  147. c10::ScalarType,
  148. std::vector<ScalingType>&,
  149. ArrayRef<Tensor>&);
  150. bool check_mxfp4_recipe(c10::ScalarType,
  151. std::vector<ScalingType>&,
  152. ArrayRef<Tensor>&,
  153. c10::ScalarType,
  154. std::vector<ScalingType>&,
  155. ArrayRef<Tensor>&);
  156. } // namespace at::cuda::scaled
  157. #else
  158. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  159. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)