GroupedMMUtils.h 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/core/Tensor.h>
  4. #include <ATen/TensorUtils.h>
  5. #ifndef AT_PER_OPERATOR_HEADERS
  6. #include <ATen/CPUFunctions.h>
  7. #include <ATen/Functions.h>
  8. #include <ATen/NativeFunctions.h>
  9. #else
  10. #include <ATen/ops/bmm.h>
  11. #include <ATen/ops/empty.h>
  12. #include <ATen/ops/empty_strided.h>
  13. #include <ATen/ops/mm.h>
  14. #include <ATen/ops/zeros.h>
  15. #endif
  16. namespace at::native {
  17. inline bool check_valid_strides_and_return_transposed(const Tensor& mat) {
  18. IntArrayRef tensor_strides = mat.strides();
  19. IntArrayRef tensor_sizes = mat.sizes();
  20. int end_dim = mat.dim() - 1;
  21. int alignment = 16 / mat.element_size();
  22. bool is_cpu = mat.device().is_cpu();
  23. TORCH_CHECK(is_cpu || uint64_t(mat.data_ptr()) % 16 == 0, "expected data_ptr to be aligned to 16 bytes");
  24. if ((tensor_strides[end_dim - 1] == 1) && (tensor_strides[end_dim] >= std::max<int64_t>(1, tensor_sizes[end_dim - 1]))) {
  25. TORCH_CHECK(tensor_strides[end_dim] % alignment == 0, "strides should be multiple of 16 bytes");
  26. return true;
  27. } else if ((tensor_strides[end_dim] == 1) && (tensor_strides[end_dim - 1] >= std::max<int64_t>(1, tensor_sizes[end_dim]))) {
  28. TORCH_CHECK(tensor_strides[end_dim - 1] % alignment == 0, "strides should be multiple of 16 bytes");
  29. return false;
  30. } else {
  31. TORCH_CHECK(false, "Invalid strides/sizes, got ", mat.strides(), " for strides and ", mat.sizes(), " for sizes");
  32. }
  33. }
  34. inline at::Tensor create_grouped_gemm_output_tensor(const Tensor& mat_a,
  35. const Tensor& mat_b,
  36. const std::optional<at::Tensor>& offs,
  37. c10::ScalarType out_dtype
  38. ) {
  39. c10::SmallVector<int64_t, 3> out_size;
  40. const bool a_is_2d = mat_a.dim() == 2;
  41. const bool b_is_2d = mat_b.dim() == 2;
  42. if (a_is_2d) {
  43. if (b_is_2d) {
  44. out_size = {offs->size(0), mat_a.size(0), mat_b.size(1)};
  45. } else {
  46. TORCH_CHECK(offs->size(0) == mat_b.size(0), "matrix batch sizes have to match");
  47. out_size = {mat_a.size(0), mat_b.size(-1)};
  48. }
  49. } else {
  50. if (b_is_2d) {
  51. // this case is not actually encountered for MoE gemms
  52. TORCH_CHECK(offs->size(0) == mat_a.size(0), "matrix batch sizes have to match");
  53. out_size = {mat_a.size(1), mat_b.size(1)};
  54. } else { // regular bmm
  55. TORCH_CHECK(mat_a.size(0) == mat_b.size(0), "batched dimension has to match");
  56. out_size = {mat_a.size(0), mat_a.size(1), mat_b.size(-1)};
  57. }
  58. }
  59. #ifndef USE_ROCM
  60. // For TMA transfers, strides of output tensor have to be either
  61. // 1, or aligned to 16 bytes.
  62. const auto last_dim = out_size.size() - 1;
  63. const auto alignment = 16 / c10::elementSize(out_dtype);
  64. const int64_t size_padded = (out_size[last_dim] + alignment - 1) / alignment * alignment;
  65. std::vector<int64_t> out_stride;
  66. if (a_is_2d != b_is_2d) {
  67. out_stride = {size_padded, 1};
  68. } else {
  69. out_stride = {out_size[1] * size_padded, size_padded, 1};
  70. }
  71. return at::empty_strided(out_size, out_stride, mat_a.options().dtype(out_dtype));
  72. #else
  73. // For ROCm 2D-2D case (output is 3D), zero-initialize to handle K=0 or small K
  74. // groups correctly. When K=0, the mathematically correct result is zeros,
  75. // but CK kernel may not write to the output region.
  76. if (a_is_2d && b_is_2d) {
  77. return at::zeros(out_size, mat_a.options().dtype(out_dtype));
  78. }
  79. return at::empty(out_size, mat_a.options().dtype(out_dtype));
  80. #endif
  81. }
  82. inline void _grouped_mm_validate_inputs(const Tensor& mat_a, const Tensor& mat_b,
  83. const std::optional<at::Tensor>& offs,
  84. const std::optional<at::Tensor>& bias,
  85. std::optional<c10::ScalarType> out_dtype) {
  86. TORCH_CHECK((mat_a.dtype() == at::kBFloat16) || (mat_a.dtype() == at::kFloat) || (mat_a.dtype() == at::kHalf), "Expected mat_a to be Float32, BFloat16 or Float16 matrix, got ", mat_a.scalar_type());
  87. TORCH_CHECK((mat_b.dtype() == at::kBFloat16) || (mat_b.dtype() == at::kFloat) || (mat_b.dtype() == at::kHalf), "Expected mat_b to be Float32, BFloat16 or Float16 matrix, got ", mat_b.scalar_type());
  88. TORCH_CHECK(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
  89. TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
  90. const bool a_is_2d = mat_a.dim() == 2;
  91. const bool b_is_2d = mat_b.dim() == 2;
  92. if (!a_is_2d || !b_is_2d) {
  93. TORCH_CHECK(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
  94. }
  95. // check that the strides are valid, the fn will throw an error if not
  96. check_valid_strides_and_return_transposed(mat_a);
  97. check_valid_strides_and_return_transposed(mat_b);
  98. TORCH_CHECK(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix, or no offset if both matrices are 3d");
  99. if (offs.has_value()) {
  100. TORCH_CHECK(offs->dim() == 1, "offs has to be 1D");
  101. TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32");
  102. }
  103. TORCH_CHECK(!bias.has_value(), "Bias not supported yet");
  104. }
  105. inline c10::ScalarType _resolve_grouped_mm_out_dtype(const Tensor& mat_a, [[maybe_unused]] const Tensor& mat_b,
  106. std::optional<c10::ScalarType> out_dtype) {
  107. const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
  108. // TODO(future PR): enable float32 output dtype for bfloat16 and float16 inputs
  109. TORCH_CHECK(out_dtype_ == mat_a.dtype(), "Grouped gemm output dtype must match `mat_a` dtype");
  110. return out_dtype_;
  111. }
  112. inline void _grouped_mm_fallback(const Tensor& mat_a, const Tensor& mat_b,
  113. const std::optional<at::Tensor>& offs,
  114. const std::optional<at::Tensor>& bias,
  115. std::optional<c10::ScalarType> out_dtype,
  116. Tensor out) {
  117. LOG(INFO) << "fallback path for `torch._grouped_mm`, performance may not be optimal";
  118. const bool a_is_2d = mat_a.dim() == 2;
  119. const bool b_is_2d = mat_b.dim() == 2;
  120. if (a_is_2d && !b_is_2d) {
  121. // 2d x 3d with offsets
  122. int group_start_idx = 0;
  123. auto offs_cpu = offs.value().cpu();
  124. for (int group_idx = 0; group_idx < offs_cpu.size(0); group_idx++) {
  125. int group_end_idx = offs_cpu[group_idx].item<int>();
  126. auto mat_a_slice = mat_a.slice(0, group_start_idx, group_end_idx);
  127. auto out_slice = out.slice(0, group_start_idx, group_end_idx);
  128. at::mm_out(out_slice, mat_a_slice, mat_b[group_idx]);
  129. group_start_idx = group_end_idx;
  130. }
  131. } else if (!a_is_2d && b_is_2d) {
  132. // 3d x 2d with offsets
  133. int group_start_idx = 0;
  134. auto offs_cpu = offs.value().cpu();
  135. for (int group_idx = 0; group_idx < offs_cpu.size(0); group_idx++) {
  136. int group_end_idx = offs_cpu[group_idx].item<int>();
  137. auto mat_b_slice = mat_b.slice(1, group_start_idx, group_end_idx);
  138. auto out_slice = out.slice(1, group_start_idx, group_end_idx);
  139. at::mm_out(out_slice, mat_a[group_idx], mat_b_slice);
  140. group_start_idx = group_end_idx;
  141. }
  142. } else if (a_is_2d && b_is_2d) {
  143. // 2d x 2d with offsets
  144. int group_start_idx = 0;
  145. auto offs_cpu = offs.value().cpu();
  146. for (int group_idx = 0; group_idx < offs_cpu.size(0); group_idx++) {
  147. int group_end_idx = offs_cpu[group_idx].item<int>();
  148. auto mat_a_slice = mat_a.slice(1, group_start_idx, group_end_idx);
  149. auto mat_b_slice = mat_b.slice(0, group_start_idx, group_end_idx);
  150. auto out_slice = out[group_idx];
  151. at::mm_out(out_slice, mat_a_slice, mat_b_slice);
  152. group_start_idx = group_end_idx;
  153. }
  154. } else {
  155. // 3d x 3d without offsets - regular bmm
  156. at::bmm_out(out, mat_a, mat_b);
  157. }
  158. }
  159. } // namespace at::native
  160. #else
  161. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  162. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)