compile.cpp 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. // SPDX-License-Identifier: MIT
  2. // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  3. /* clang-format off */
  4. #include <stdio.h>
  5. #include <stdint.h>
  6. #include <inttypes.h>
  7. #include <string.h>
  8. #include <hip/hip_runtime.h>
  9. // helpers to check for hip errors
  // Evaluates a HIP call once and passes its status to gpuAssert together with
  // the file and line of the call site.
  // The expansion is a do/while(0) so that `HIP_CHECK(x);` is exactly one
  // statement, also inside an unbraced if/else. The last line of the macro
  // deliberately has no line continuation, so the definition ends here and
  // cannot absorb whatever line follows it.
  #define HIP_CHECK(ans) do {{ \
      gpuAssert((ans), __FILE__, __LINE__); \
    }} while (0)
  13. static inline void gpuAssert(hipError_t code, const char *file, int line) {{
  14. if (code != hipSuccess) {{
  15. const char *prefix = "Triton Error [HIP]: ";
  16. const char *str;
  17. hipDrvGetErrorString(code, &str);
  18. char err[1024] = {{0}};
  19. strcat(err, prefix);
  20. strcat(err, str);
  21. printf("%s\\n", err);
  22. exit(code);
  23. }}
  24. }}
  25. // globals
  26. #define HSACO_NAME {kernel_name}_hsaco
  27. hipModule_t {kernel_name}_mod = nullptr;
  28. hipFunction_t {kernel_name}_func = nullptr;
  29. unsigned char HSACO_NAME[{bin_size}] = {{ {bin_data} }};
  30. void unload_{kernel_name}(void) {{
  31. HIP_CHECK(hipModuleUnload({kernel_name}_mod));
  32. }}
  33. void load_{kernel_name}() {{
  34. int dev = 0;
  35. void *bin = (void *)&HSACO_NAME;
  36. int shared = {shared};
  37. HIP_CHECK(hipModuleLoadData(&{kernel_name}_mod, bin));
  38. HIP_CHECK(hipModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}"));
  39. }}
  40. /*
  41. {kernel_docstring}
  42. */
  43. hipError_t {kernel_name}(hipStream_t stream, {signature}) {{
  44. if ({kernel_name}_func == nullptr)
  45. load_{kernel_name}();
  46. unsigned int gX = {gridX};
  47. unsigned int gY = {gridY};
  48. unsigned int gZ = {gridZ};
  49. hipDeviceptr_t global_scratch = 0;
  50. hipDeviceptr_t profile_scratch = 0;
  51. void *args[{num_args}] = {{ {arg_pointers} }};
  52. // TODO: shared memory
  53. if(gX * gY * gZ > 0)
  54. return hipModuleLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * {warp_size}, 1, 1, {shared}, stream, args, nullptr);
  55. else
  56. return hipErrorInvalidValue;
  57. }}