driver.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978
  1. import functools
  2. import os
  3. import platform
  4. import subprocess
  5. import re
  6. import triton
  7. from pathlib import Path
  8. from triton import knobs
  9. from triton.backends.compiler import GPUTarget
  10. from triton.backends.driver import GPUDriver
  11. from triton.runtime import _allocation
  12. from triton.runtime.build import compile_module_from_src
  # Directory containing this module (symlinks resolved); used below to locate
  # files that live next to it, such as "include" and "driver.c".
  dirname = os.path.dirname(os.path.realpath(__file__))
  # Header search paths passed to compile_module_from_src; starts with the
  # "include" directory next to this file.
  include_dirs = [os.path.join(dirname, "include")]
  # Placeholder; rebound to the compiled module's PyTDMDescriptor type by
  # HIPUtils.__init__.
  PyTDMDescriptor = None
  def _is_windows():
      """Return True when ``platform.system()`` reports Windows."""
      system_name = platform.system()
      return system_name == 'Windows'
  def _get_rocm_sdk_root():
      """Get the ROCm SDK root path from the ``rocm-sdk`` command or the environment.

      The ``rocm-sdk path --root`` command is tried first. If it is missing,
      cannot be executed, exits non-zero, or prints something that is not an
      existing directory, the ``ROCM_HOME``, ``HIP_PATH`` and ``ROCM_PATH``
      environment variables are consulted in that order.

      Returns:
          The first candidate that is an existing directory, or ``None``.
      """
      # Try rocm-sdk path --root first (for Windows ROCm SDK).
      # This function runs at module import time, so any launch failure
      # (missing binary, permission denied, ...) must degrade to the
      # environment-variable fallback instead of aborting the import.
      try:
          output = subprocess.check_output(["rocm-sdk", "path", "--root"], stderr=subprocess.DEVNULL)
          root = output.decode().strip()
      except (subprocess.CalledProcessError, OSError, UnicodeDecodeError):
          root = ""
      if root and os.path.isdir(root):
          return root

      # Fall back to environment variables
      for env_var in ("ROCM_HOME", "HIP_PATH", "ROCM_PATH"):
          path = os.environ.get(env_var, "")
          if path and os.path.isdir(path):
              return path
      return None
  def _get_hip_library_from_rocm_sdk():
      """Ask the optional ``rocm_sdk`` package where the amdhip64 library lives.

      Returns:
          ``str`` of the first entry returned by
          ``rocm_sdk.find_libraries("amdhip64")``, or ``None`` when the package
          cannot be imported, the lookup raises ``FileNotFoundError``, or the
          result is empty.
      """
      try:
          import rocm_sdk
          found = rocm_sdk.find_libraries("amdhip64")
          first = str(found[0]) if found else None
      except (ImportError, ModuleNotFoundError, FileNotFoundError):
          first = None
      return first
  # Add HIP runtime headers from ROCm SDK if available.
  # This runs once at import time: when a ROCm root is found and it contains an
  # "include" directory, that directory is appended to include_dirs (after the
  # "include" directory next to this file).
  _rocm_root = _get_rocm_sdk_root()
  if _rocm_root and os.path.isdir(os.path.join(_rocm_root, "include")):
      include_dirs.append(os.path.join(_rocm_root, "include"))
  def _find_already_mmapped_dylib_on_linux(lib_name):
      """Return the path of an already-loaded shared object whose file name contains ``lib_name``.

      Walks the objects loaded into the current process with ``dl_iterate_phdr``
      (see https://www.man7.org/linux/man-pages/man3/dl_iterate_phdr.3.html).

      Args:
          lib_name: Substring to look for in the file name of each loaded object.

      Returns:
          The path of the first matching object, or ``None`` when not running on
          Linux, when ``libc.so.6`` / ``dl_iterate_phdr`` cannot be obtained, or
          when nothing matches.
      """
      if platform.system() != 'Linux':
          return None

      import ctypes
      from ctypes import c_char, c_int, c_size_t, c_void_p, c_char_p, POINTER

      class DlPhdrInfo(ctypes.Structure):
          # Only the two leading members are declared; the remaining fields of
          # the C struct are never read here.
          _fields_ = [
              ('dlpi_addr', c_void_p),
              ('dlpi_name', c_char_p),
          ]

      # The user-data parameter is typed POINTER(c_char) so the callback gets
      # the caller's buffer itself rather than a copy.
      callback_t = ctypes.CFUNCTYPE(c_int, POINTER(DlPhdrInfo), POINTER(c_size_t), POINTER(c_char))

      try:
          libc = ctypes.CDLL('libc.so.6')
          dl_iterate_phdr = libc.dl_iterate_phdr
      except Exception:
          return None
      # argtypes uses c_char_p so a create_string_buffer can be passed directly.
      dl_iterate_phdr.restype = c_int
      dl_iterate_phdr.argtypes = [callback_t, c_char_p]

      max_path_length = 4096
      out_buf = ctypes.create_string_buffer(max_path_length + 1)

      def on_object(info, size, data):
          # Returning non-zero stops the iteration; 0 moves on to the next object.
          name_bytes = info.contents.dlpi_name
          if lib_name not in Path(os.fsdecode(name_bytes)).name:
              return 0
          # Copy at most max_path_length bytes so the zero-filled buffer keeps
          # its trailing NUL.
          ctypes.memmove(data, name_bytes, min(max_path_length, len(name_bytes)))
          return 1

      found = dl_iterate_phdr(callback_t(on_object), out_buf)
      if not found:
          return None
      return os.fsdecode(ctypes.string_at(out_buf))
  @functools.lru_cache()
  def _get_path_to_hip_runtime_dylib():
      """Resolve the filesystem path of the HIP runtime dynamic library.

      The library name is ``amdhip64.dll`` on Windows and ``libamdhip64.so``
      elsewhere. Candidates are tried in this order and the first existing one
      wins:

      1. ``knobs.amd.libhip_path`` (explicit override; must end with the library
         name and exist, otherwise an error is raised immediately).
      2. ``_get_hip_library_from_rocm_sdk()``.
      3. A copy already mapped into this process (non-Windows only).
      4. ``lib/`` next to this file.
      5. ``torch/lib`` under the user and system site-packages directories.
      6. Directories listed in ``PATH`` (Windows) or ``LD_LIBRARY_PATH``.
      7. ``HIP_PATH``.
      8. The root printed by ``rocm-sdk path --root`` (Windows) or
         ``hipconfig --path``.
      9. ``ROCM_PATH`` or ``ROCM_HOME``.
      10. Entries printed by ``/sbin/ldconfig -p`` (non-Windows only).
      11. ``/opt/rocm/lib``.

      The result is memoized by ``functools.lru_cache``.

      Returns:
          Path to the library as a string.

      Raises:
          RuntimeError: If the override or the already-mapped path is invalid, or
              if no candidate exists (the message lists the attempted paths).
      """
      on_windows = _is_windows()
      lib_name = "amdhip64.dll" if on_windows else "libamdhip64.so"
      # On Windows, DLLs are in bin; on Linux, .so files are in lib
      lib_subdir = "bin" if on_windows else "lib"

      # If we are told explicitly what HIP runtime dynamic library to use, obey that.
      if env_libhip_path := knobs.amd.libhip_path:
          if env_libhip_path.endswith(lib_name) and os.path.exists(env_libhip_path):
              return env_libhip_path
          raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}")

      # Try rocm_sdk.find_libraries first - this is the preferred method
      rocm_sdk_path = _get_hip_library_from_rocm_sdk()
      if rocm_sdk_path:
          return rocm_sdk_path

      # If the shared object is already mmapped to address space, use it (Linux only).
      if not on_windows:
          mmapped_path = _find_already_mmapped_dylib_on_linux(lib_name)
          if mmapped_path:
              if os.path.exists(mmapped_path):
                  return mmapped_path
              raise RuntimeError(f"memory mapped '{mmapped_path}' in process does not point to a valid {lib_name}")

      # Every candidate that was checked and did not exist; reported on failure.
      paths = []

      # Check backend
      local_lib = os.path.join(os.path.dirname(__file__), "lib", lib_name)
      if os.path.exists(local_lib):
          return local_lib
      paths.append(local_lib)

      import site
      # First search the HIP runtime dynamic library packaged with PyTorch. It's very likely
      # that we run Triton together with PyTorch. This makes sure we use the same dynamic
      # library to avoid version mismatch.
      site_packages = site.getsitepackages()
      user_site = site.getusersitepackages()
      if site.ENABLE_USER_SITE:  # ENABLE_USER_SITE is initialized in getusersitepackages()
          site_packages = [user_site] + site_packages
      for path in site_packages:
          path = os.path.join(path, "torch", "lib", lib_name)
          if os.path.exists(path):
              return path
          paths.append(path)

      # Then try to see if developer provides a HIP runtime dynamic library using
      # LD_LIBRARY_PATH (Linux) or PATH (Windows).
      if on_windows:
          env_path = os.getenv("PATH", "")
          path_sep = ";"
      else:
          env_path = os.getenv("LD_LIBRARY_PATH", "")
          path_sep = ":"
      if env_path:
          for d in env_path.split(path_sep):
              f = os.path.join(d, lib_name)
              if os.path.exists(f):
                  return f
              paths.append(f)

      # HIP_PATH should point to HIP SDK root if set
      env_hip_path = os.getenv("HIP_PATH")
      if env_hip_path:
          hip_lib_path = os.path.join(env_hip_path, lib_subdir, lib_name)
          if os.path.exists(hip_lib_path):
              return hip_lib_path
          paths.append(hip_lib_path)

      # Try rocm-sdk path --root (Windows ROCm SDK) or hipconfig --path (Linux)
      try:
          if on_windows:
              rocm_root = subprocess.check_output(["rocm-sdk", "path", "--root"],
                                                  stderr=subprocess.DEVNULL).decode().strip()
          else:
              rocm_root = subprocess.check_output(["hipconfig", "--path"]).decode().strip()
          if rocm_root:
              rocm_lib_path = os.path.join(rocm_root, lib_subdir, lib_name)
              if os.path.exists(rocm_lib_path):
                  return rocm_lib_path
              paths.append(rocm_lib_path)
      except (subprocess.CalledProcessError, OSError):
          # rocm-sdk or hipconfig may not be available or not runnable; any
          # launch failure just means this strategy yields nothing.
          pass

      # ROCm lib dir based on env var
      env_rocm_path = os.getenv("ROCM_PATH") or os.getenv("ROCM_HOME")
      if env_rocm_path:
          rocm_lib_path = os.path.join(env_rocm_path, lib_subdir, lib_name)
          if os.path.exists(rocm_lib_path):
              return rocm_lib_path
          paths.append(rocm_lib_path)

      # Afterwards try to search the loader dynamic library resolution paths (Linux only).
      if not on_windows:
          try:
              libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore")
              # each line looks like the following:
              # libamdhip64.so.6 (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so.6
              # libamdhip64.so (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so
              locs = [line.split()[-1] for line in libs.splitlines() if line.strip().endswith(lib_name)]
              for loc in locs:
                  if os.path.exists(loc):
                      return loc
                  paths.append(loc)
          except (subprocess.CalledProcessError, OSError):
              # ldconfig missing or not runnable: fall through to the last resort.
              pass

      # As a last resort on Linux, guess if we have it in some common installation path.
      common_install_path = os.path.join('/opt/rocm/lib/', lib_name)
      if os.path.exists(common_install_path):
          return common_install_path
      paths.append(common_install_path)

      raise RuntimeError(f"cannot locate {lib_name} after attempted paths {paths}")
  class HIPUtils:
      """Singleton wrapper around the compiled ``hip_utils`` helper module.

      Every ``HIPUtils()`` call returns the same object (cached on the class as
      ``instance``). ``__init__`` builds the helper module from ``driver.c`` and
      exposes ``load_binary``, ``get_device_properties`` and
      ``create_tdm_descriptor`` as attributes.
      """

      def __new__(cls):
          # Create the shared instance on first use and reuse it afterwards.
          if not hasattr(cls, "instance"):
              cls.instance = super().__new__(cls)
          return cls.instance

      def __init__(self):
          global PyTDMDescriptor

          hip_lib = _get_path_to_hip_runtime_dylib()
          # Escape backslashes for C string embedding
          c_literal_path = hip_lib.replace("\\", "\\\\")

          template = Path(os.path.join(dirname, "driver.c")).read_text()
          # A plain one-shot search-and-replace is used instead of templates or
          # format strings, so the curly brackets in the C code need no
          # escaping and the placeholder is substituted exactly once.
          src = template.replace('/*py_libhip_search_path*/', c_literal_path, 1)

          mod = compile_module_from_src(src=src, name="hip_utils", include_dirs=include_dirs)
          self.load_binary = mod.load_binary
          self.get_device_properties = mod.get_device_properties
          self.create_tdm_descriptor = mod.create_tdm_descriptor
          # Publish the descriptor type at module level.
          PyTDMDescriptor = mod.PyTDMDescriptor
  208. # -------------------- Launcher ----------------------------
  def ty_to_cpp(ty):
      """Map a Triton signature type string to the C type name used by the launcher.

      Args:
          ty: Type string such as ``"*fp32"``, ``"tensordesc"``, ``"i32"`` or
              ``"fp16"``.

      Returns:
          ``"hipDeviceptr_t"`` for anything starting with ``*``,
          ``"TDMDescriptor"`` for ``"tensordesc"``, otherwise the entry of the
          scalar table (all floating-point types map to ``"double"``).

      Raises:
          KeyError: If ``ty`` is not a pointer, ``"tensordesc"`` or a known scalar.
      """
      if ty.startswith('*'):
          cpp_type = "hipDeviceptr_t"
      elif ty == "tensordesc":
          cpp_type = "TDMDescriptor"
      else:
          scalar_types = {
              "i1": "int8_t",
              "i8": "int8_t",
              "i16": "int16_t",
              "i32": "int32_t",
              "i64": "int64_t",
              "u1": "uint8_t",
              "u8": "uint8_t",
              "u16": "uint16_t",
              "u32": "uint32_t",
              "u64": "uint64_t",
          }
          # Every floating-point flavour maps to double.
          scalar_types.update(dict.fromkeys(("fp16", "bf16", "fp32", "f32", "fp64"), "double"))
          cpp_type = scalar_types[ty]
      return cpp_type
  # For each floating-point signature type, the unsigned integer C type that
  # make_launcher uses to hold the packed value passed to the kernel.
  FLOAT_STORAGE_TYPE = {
      "fp16": "uint16_t",
      "bf16": "uint16_t",
      "fp32": "uint32_t",
      "f32": "uint32_t",
      "fp64": "uint64_t",
  }
  # For each floating-point signature type, the name of the C helper in the
  # source generated by make_launcher that converts the parsed double into the
  # storage type above.
  FLOAT_PACK_FUNCTION = {
      "fp16": "pack_fp16",
      "bf16": "pack_bf16",
      "fp32": "pack_fp32",
      "f32": "pack_fp32",
      "fp64": "pack_fp64",
  }
  # PyArg_ParseTuple format for the fixed leading launch() arguments, in the
  # order the generated code parses them: launch_cooperative_grid (p),
  # gridX/gridY/gridZ (iii), stream and function (KK), then profile_scratch,
  # kernel_metadata, launch_metadata, launch_enter_hook, launch_exit_hook (OOOOO).
  _BASE_ARGS_FORMAT = "piiiKKOOOOO"
  def make_launcher(constants, signature, warp_size, tensordesc_meta):
      """Generate the C source of the ``__triton_launcher`` extension module.

      The returned source defines a ``launch`` function that parses the Python
      call arguments, converts pointers / tensor descriptors / floats to their
      kernel-side representation, and launches the kernel through symbols
      resolved from the HIP runtime library.

      Args:
          constants: Not referenced by this function.
          signature: Mapping whose values are the kernel argument type strings
              (for example ``"*fp32"``, ``"i32"``, ``"constexpr"`` or
              ``"tensordesc<dtype[shape]>"``); only ``.values()`` is used.
          warp_size: Value interpolated into the generated code as the factor
              multiplied by ``num_warps`` to obtain the block size.
          tensordesc_meta: Sequence indexed by the position of each tensor
              descriptor in ``signature``, or a falsy value. A ``None`` entry (or
              a falsy sequence) means the descriptor is passed decomposed into
              base pointer, shape, strides and a flag.

      Returns:
          The C source code as a string.
      """

      def _expand_signature(signature):
          # Replace every tensordesc<dtype[shape]> entry with the list of types
          # that are actually passed for it.
          output = []
          tensordesc_idx = 0
          for sig in signature:
              if isinstance(sig, str) and sig.startswith("tensordesc"):
                  meta = tensordesc_meta[tensordesc_idx] if tensordesc_meta else None
                  tensordesc_idx += 1

                  match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", sig)
                  dtype = match.group(1)
                  shape = match.group(2)
                  ndim = shape.count(",") + 1

                  # If there is no descriptor's metadata, the descriptor has been
                  # decomposed to base pointer, shape and strides
                  if meta is None:
                      output.append("*" + dtype)
                      for _ in range(2 * ndim):
                          output.append("i64")
                      output.append("i1")
                  else:
                      output.append("tensordesc")

                  # In both cases ndim i32 values and ndim i64 values follow.
                  for _ in range(ndim):
                      output.append("i32")
                  for _ in range(ndim):
                      output.append("i64")
              else:
                  output.append(sig)
          return output

      def _serialize_signature(sig):
          # Flatten (possibly nested) tuples into one comma-separated string.
          if isinstance(sig, tuple):
              return ','.join(map(_serialize_signature, sig))
          return sig

      def _extracted_type(ty):
          # C type of the local variable that PyArg_ParseTuple fills for `ty`.
          if isinstance(ty, tuple):
              val = ','.join(map(_extracted_type, ty))
              return f"[{val}]"
          if ty.startswith("*") or ty.startswith("tensordesc"):
              return "PyObject*"
          if ty == "constexpr":
              return "PyObject*"
          return ty_to_cpp(ty)

      def format_of(ty):
          # PyArg_ParseTuple format unit for `ty`.
          if isinstance(ty, tuple):
              val = ''.join(map(format_of, ty))
              return f"({val})"
          if ty.startswith("*") or ty.startswith("tensordesc"):
              return "O"
          if ty == "constexpr":
              return "O"
          return {
              "double": "d",
              "long": "l",
              "int8_t": "b",
              "int16_t": "h",
              "int32_t": "i",
              "int64_t": "L",
              "uint8_t": "B",
              "uint16_t": "H",
              "uint32_t": "I",
              "uint64_t": "K",
          }[ty_to_cpp(ty)]

      signature = {idx: s for idx, s in enumerate(_expand_signature(signature.values()))}
      args_format = ''.join([format_of(ty) for ty in signature.values()])
      parse_format = _BASE_ARGS_FORMAT + args_format

      # Flatten tuples and drop empty entries so every remaining item is a
      # single type string keyed by its position.
      signature = ','.join(map(_serialize_signature, signature.values()))
      signature = list(filter(bool, signature.split(',')))
      signature = {i: s for i, s in enumerate(signature)}
      args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''

      # Parameter declarations of the generated _launch(); constexpr arguments
      # are not passed to the kernel and floats travel as packed integers.
      arg_decl_list = []
      for i, ty in signature.items():
          if ty == "constexpr":
              continue
          if ty in FLOAT_STORAGE_TYPE:
              arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
          else:
              arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
      arg_decls = ', '.join(arg_decl_list)

      # Expressions passed from the generated launch() to _launch().
      internal_args_list = []
      for i, ty in signature.items():
          if ty.startswith("*"):
              internal_args_list.append(f"ptr_info{i}.dev_ptr")
          elif ty.startswith("tensordesc"):
              internal_args_list.append(f"*desc{i}")
          elif ty in FLOAT_STORAGE_TYPE:
              internal_args_list.append(f"_arg{i}_storage")
          elif ty != "constexpr":
              internal_args_list.append(f"_arg{i}")

      newline = '\n '
      ptr_decls = [
          f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;"
          for i, ty in signature.items()
          if ty.startswith("*")
      ]
      # getTDMDescriptor returns NULL (with a TypeError set) on a type mismatch;
      # bail out before the descriptor is dereferenced in the _launch() call.
      tensor_desc_decls = [
          f"TDMDescriptor* desc{i} = getTDMDescriptor(_arg{i}, {i}); if (!desc{i}) return NULL;"
          for i, ty in signature.items()
          if ty.startswith("tensordesc")
      ]
      float_storage_decls = [
          f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
          for i, ty in signature.items()
          if ty in FLOAT_STORAGE_TYPE
      ]

      libhip_path = _get_path_to_hip_runtime_dylib()
      # Escape backslashes for C string embedding
      libhip_path = libhip_path.replace("\\", "\\\\")

      # generate glue code
      params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
      params.append("&global_scratch")
      params.append("&profile_scratch")

      # Platform-specific includes and dlopen/dlsym macros
      if _is_windows():
          platform_includes = """
  #define __HIP_PLATFORM_AMD__
  #include <hip/hip_runtime.h>
  #include <hip/hip_runtime_api.h>
  #include <Python.h>
  #include <windows.h>
  #include <stdbool.h>
  // Windows compatibility layer for dlopen/dlsym/dlclose
  static char _dlerror_buf[512];
  static inline void *dlopen(const char *filename, int flags) {
    (void)flags;
    HMODULE h = LoadLibraryA(filename);
    if (!h) {
      snprintf(_dlerror_buf, sizeof(_dlerror_buf), "LoadLibrary failed with error %lu", GetLastError());
    }
    return (void *)h;
  }
  static inline void *dlsym(void *handle, const char *symbol) {
    void *p = (void *)GetProcAddress((HMODULE)handle, symbol);
    if (!p) {
      snprintf(_dlerror_buf, sizeof(_dlerror_buf), "GetProcAddress failed for %s with error %lu", symbol, GetLastError());
    }
    return p;
  }
  static inline int dlclose(void *handle) { return FreeLibrary((HMODULE)handle) ? 0 : -1; }
  static inline const char *dlerror(void) { return _dlerror_buf[0] ? _dlerror_buf : NULL; }
  #define RTLD_LAZY 0
  #define RTLD_LOCAL 0
  #define RTLD_NOLOAD 0
  """
      else:
          platform_includes = """
  #define __HIP_PLATFORM_AMD__
  #include <hip/hip_runtime.h>
  #include <hip/hip_runtime_api.h>
  #include <Python.h>
  #include <dlfcn.h>
  #include <stdbool.h>
  """

      src = f"""{platform_includes}
  typedef struct {{
    uint32_t group0_0;
    uint32_t group0_1;
    uint32_t group0_2;
    uint32_t group0_3;
    uint32_t group1_0;
    uint32_t group1_1;
    uint32_t group1_2;
    uint32_t group1_3;
    uint32_t group1_4;
    uint32_t group1_5;
    uint32_t group1_6;
    uint32_t group1_7;
  }} TDMDescriptor;
  typedef struct {{
    PyObject_HEAD;
    TDMDescriptor desc;
  }} PyTDMDescriptorObject;
  // The list of paths to search for the HIP runtime library. The caller Python
  // code should substitute the search path placeholder.
  static const char *hipLibSearchPaths[] = {{"{libhip_path}"}};
  // The list of HIP dynamic library symbols and their signature we are interested
  // in this file.
  #define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN) \\
    FOR_EACH_STR_FN(hipGetLastError, true) \\
    FOR_EACH_STR_FN(hipGetErrorString, true, hipError_t hipError) \\
    FOR_EACH_ERR_FN(hipDrvLaunchKernelEx, false, \\
                    const HIP_LAUNCH_CONFIG *config, \\
                    hipFunction_t f, \\
                    void **kernelParams, \\
                    void **extra) \\
    FOR_EACH_ERR_FN(hipModuleLaunchKernel, true, hipFunction_t f, \\
                    unsigned int gridDimX, unsigned int gridDimY, \\
                    unsigned int gridDimZ, unsigned int blockDimX, \\
                    unsigned int blockDimY, unsigned int blockDimZ, \\
                    unsigned int sharedMemBytes, hipStream_t stream, \\
                    void **kernelParams, void **extra) \\
    FOR_EACH_ERR_FN(hipModuleLaunchCooperativeKernel, true, hipFunction_t f, \\
                    unsigned int gridDimX, unsigned int gridDimY, \\
                    unsigned int gridDimZ, unsigned int blockDimX, \\
                    unsigned int blockDimY, unsigned int blockDimZ, \\
                    unsigned int sharedMemBytes, hipStream_t stream, \\
                    void **kernelParams, void **extra) \\
    FOR_EACH_ERR_FN(hipPointerGetAttribute, true, void *data, \\
                    hipPointer_attribute attribute, hipDeviceptr_t ptr)
  // The HIP symbol table for holding resolved dynamic library symbols.
  struct HIPSymbolTable {{
  #define DEFINE_EACH_ERR_FIELD(hipSymbolName, required, ...) \\
    hipError_t (*hipSymbolName)(__VA_ARGS__);
  #define DEFINE_EACH_STR_FIELD(hipSymbolName, required, ...) \\
    const char *(*hipSymbolName)(__VA_ARGS__);
    HIP_SYMBOL_LIST(DEFINE_EACH_ERR_FIELD, DEFINE_EACH_STR_FIELD)
  }};
  static struct HIPSymbolTable hipSymbolTable;
  bool initSymbolTable() {{
    void *lib = NULL;
    // Go through the list of search paths to open the first HIP driver library.
    int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]);
    for (int i = 0; i < n; ++i) {{
      void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
      if (handle) {{
        lib = handle;
        break;
      }}
    }}
    if (!lib) {{
      PyErr_SetString(PyExc_RuntimeError, "cannot open HIP runtime library");
      return false;
    }}
    typedef hipError_t (*hipGetProcAddress_fn)(
        const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
        hipDriverProcAddressQueryResult *symbolStatus);
    hipGetProcAddress_fn hipGetProcAddress;
    dlerror(); // Clear existing errors
    const char *error = NULL;
    *(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
    error = dlerror();
    if (error) {{
      PyErr_SetString(PyExc_RuntimeError,
                      "cannot query 'hipGetProcAddress' from HIP runtime library");
      dlclose(lib);
      return false;
    }}
    // Resolve all symbols we are interested in.
    int hipVersion = HIP_VERSION;
    uint64_t hipFlags = 0;
    hipDriverProcAddressQueryResult symbolStatus;
    hipError_t status = hipSuccess;
  #define QUERY_EACH_FN(hipSymbolName, required, ...) \
    status = hipGetProcAddress(#hipSymbolName, \
                               (void **)&hipSymbolTable.hipSymbolName, \
                               hipVersion, hipFlags, &symbolStatus); \
    if (required && status != hipSuccess) {{ \
      PyErr_SetString(PyExc_RuntimeError, \
                      "cannot get address for '" #hipSymbolName \
                      "' from libamdhip64.so"); \
      dlclose(lib); \
      return false; \
    }}
    HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN)
    return true;
  }}
  static inline void gpuAssert(hipError_t code, const char *file, int line)
  {{
    if (code != HIP_SUCCESS)
    {{
      const char* prefix = "Triton Error [HIP]: ";
      const char* str = hipSymbolTable.hipGetErrorString(code);
      char err[1024] = {{0}};
      snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
      PyErr_SetString(PyExc_RuntimeError, err);
    }}
  }}
  #define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
  static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
    if (gridX * gridY * gridZ == 0)
      return;
    hipDeviceptr_t global_scratch = 0;
    void *params[] = {{ {', '.join(params)} }};
    if(num_ctas > 1) {{
      if (!hipSymbolTable.hipDrvLaunchKernelEx) {{
        PyErr_SetString(PyExc_RuntimeError, "missing hipDrvLaunchKernelEx symbol; please update HIP runtime");
        return;
      }}
      hipLaunchAttribute attributes[2];
      // Attribute0: Cluster dimensions
      attributes[0].id = 4;
      int *cluster_dims = (int*)attributes[0].val.pad;
      cluster_dims[0] = num_ctas;
      cluster_dims[1] = 1;
      cluster_dims[2] = 1;
      // Attribute1: Cooperative launch
      attributes[1].id = hipLaunchAttributeCooperative;
      attributes[1].val.cooperative = launch_cooperative_grid;
      HIP_LAUNCH_CONFIG config = {{
        gridX * num_ctas, gridY, gridZ, // Grid size
        {warp_size} * num_warps, 1, 1, // Block size
        shared_memory, stream,
        attributes, 2 // Number of attributes
      }};
      HIP_CHECK(hipSymbolTable.hipDrvLaunchKernelEx(&config, function, params, 0));
      return;
    }}
    else if (launch_cooperative_grid) {{
      HIP_CHECK(hipSymbolTable.hipModuleLaunchCooperativeKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
      return;
    }}
    else {{
      HIP_CHECK(hipSymbolTable.hipModuleLaunchKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
    }}
  }}
  typedef struct _DevicePtrInfo {{
    hipDeviceptr_t dev_ptr;
    bool valid;
  }} DevicePtrInfo;
  static PyObject* data_ptr_str = NULL;
  static PyObject* py_tdm_descriptor_type = NULL;
  static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
    DevicePtrInfo ptr_info;
    hipError_t status = hipSuccess;
    ptr_info.dev_ptr = 0;
    ptr_info.valid = true;
    if (PyLong_Check(obj)) {{
      ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
      return ptr_info;
    }}
    if (obj == Py_None) {{
      // valid nullptr
      return ptr_info;
    }}
    PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
    if (!ret) {{
      PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
      ptr_info.valid = false;
      goto cleanup;
    }}
    if (!PyLong_Check(ret)) {{
      PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
      ptr_info.valid = false;
      goto cleanup;
    }}
    ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
    if (!ptr_info.dev_ptr)
      goto cleanup;
    uint64_t dev_ptr;
    status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
    if (status == hipErrorInvalidValue) {{
      PyErr_Format(PyExc_ValueError,
                   "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
      ptr_info.valid = false;
      // Clear and ignore HIP error
      (void)hipSymbolTable.hipGetLastError();
    }}
    ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
  cleanup:
    // ret is NULL when the data_ptr call itself failed, so use the NULL-safe form.
    Py_XDECREF(ret);
    return ptr_info;
  }}
  static inline TDMDescriptor* getTDMDescriptor(PyObject* obj, int idx) {{
    if (Py_TYPE(obj) != (PyTypeObject*)py_tdm_descriptor_type) {{
      PyErr_Format(PyExc_TypeError, "object must be of type PyTDMDescriptor, got %s", Py_TYPE(obj)->tp_name);
      return NULL;
    }}
    TDMDescriptor* desc = &((PyTDMDescriptorObject*)obj)->desc;
    return desc;
  }}
  static uint16_t pack_fp16(double f) {{
    uint16_t result;
    // from https://github.com/python/pythoncapi-compat/blob/5e317108f872c904eb726cb8d560dcadbdf88a72/pythoncapi_compat.h#L482-L492
  #if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
    _PyFloat_Pack2(f, (unsigned char*)&result, 1);
  #else
    PyFloat_Pack2(f, (char*)&result, 1);
  #endif
    return result;
  }}
  static uint16_t pack_bf16(double f) {{
    float f32 = (float)f;
    uint32_t u32 = *(uint32_t*)&f32;
    return (uint16_t)(u32 >> 16);
  }}
  static uint32_t pack_fp32(double f) {{
    float f32 = (float)f;
    return *(uint32_t*)&f32;
  }}
  static uint64_t pack_fp64(double f) {{
    return *(uint64_t*)&f;
  }}
  static PyObject* launch(PyObject* self, PyObject* args) {{
    int gridX, gridY, gridZ;
    uint64_t _stream;
    uint64_t _function;
    int launch_cooperative_grid;
    PyObject *profile_scratch_obj = NULL;
    PyObject *launch_enter_hook = NULL;
    PyObject *launch_exit_hook = NULL;
    PyObject *kernel_metadata = NULL;
    PyObject *launch_metadata = NULL;
    {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
    if(!PyArg_ParseTuple(args, \"{parse_format}\", &launch_cooperative_grid,
                         &gridX, &gridY, &gridZ, &_stream, &_function, &profile_scratch_obj,
                         &kernel_metadata, &launch_metadata,
                         &launch_enter_hook, &launch_exit_hook {args_list})) {{
      return NULL;
    }}
    // extract kernel metadata
    int num_warps, num_ctas, shared_memory;
    if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
      return NULL;
    }}
    // extract launch metadata
    if (launch_enter_hook != Py_None){{
      PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
      if (!ret)
        return NULL;
      Py_DECREF(ret);
    }}
    hipDeviceptr_t profile_scratch = 0;
    if (profile_scratch_obj != Py_None) {{
      DevicePtrInfo profile_scratch_info = getPointer(profile_scratch_obj, -1);
      if (!profile_scratch_info.valid) {{
        return NULL;
      }}
      profile_scratch = profile_scratch_info.dev_ptr;
    }}
    // raise exception asap
    {newline.join(tensor_desc_decls)}
    {newline.join(ptr_decls)}
    {newline.join(float_storage_decls)}
    _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
    if(launch_exit_hook != Py_None){{
      PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
      if (!ret)
        return NULL;
      Py_DECREF(ret);
    }}
    if(PyErr_Occurred()) {{
      return NULL;
    }}
    Py_RETURN_NONE;
  }}
  static PyMethodDef ModuleMethods[] = {{
    {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
    {{NULL, NULL, 0, NULL}} // sentinel
  }};
  static struct PyModuleDef ModuleDef = {{
    PyModuleDef_HEAD_INIT,
    \"__triton_launcher\",
    NULL, //documentation
    -1, //size
    ModuleMethods
  }};
  PyMODINIT_FUNC PyInit___triton_launcher(void) {{
    if (!initSymbolTable()) {{
      return NULL;
    }}
    PyObject *m = PyModule_Create(&ModuleDef);
    if(m == NULL) {{
      return NULL;
    }}
    data_ptr_str = PyUnicode_InternFromString("data_ptr");
    if(data_ptr_str == NULL) {{
      return NULL;
    }}
    PyObject* driver_mod = PyImport_ImportModule("triton.backends.amd.driver");
    if (driver_mod == NULL) {{
      return NULL;
    }}
    py_tdm_descriptor_type = PyObject_GetAttrString(driver_mod, "PyTDMDescriptor");
    if (py_tdm_descriptor_type == NULL) {{
      return NULL;
    }}
    PyModule_AddFunctions(m, ModuleMethods);
    return m;
  }}
  """
      return src
  716. def make_tensordesc_arg(arg, kernel_metadata, tensordesc_metadata):
  717. """
  718. Translate a tensor descriptor argument into the appropriate list of kernel
  719. arguments. If `tensordesc_metadata` is provided, we will create a
  720. TDMDescriptor object. Otherwise, we decompose the tensor descriptor into
  721. base pointer, shape, strides, and padding flag. In both cases, we append the
  722. shape and strides at the end to match the expected kernel signature.
  723. """
  724. if tensordesc_metadata is None:
  725. # Currently the host side tensor descriptors get decomposed in
  726. # the frontend to tensor desc, shape, and strides. We have no
  727. # way to use these shape and strides when processing tensor
  728. # descriptors which is why we provide our own decomposition
  729. # above. Sadly this means we have to pass the shape and strides
  730. # twice.
  731. return [arg.base, *arg.shape, *arg.strides, arg.padding == "nan", *arg.shape, *arg.strides]
  732. shape = arg.shape
  733. strides = arg.strides
  734. base = arg.base.data_ptr()
  735. assert "elem_bits" in tensordesc_metadata and "block_size" in tensordesc_metadata
  736. elem_bits = tensordesc_metadata["elem_bits"]
  737. block_size = tensordesc_metadata["block_size"]
  738. pad_interval, pad_amount = 0, 0
  739. interval_padding_pairs = tensordesc_metadata.get("interval_padding_pairs", [])
  740. if interval_padding_pairs:
  741. assert len(interval_padding_pairs) == 1 and len(interval_padding_pairs[0]) == 2
  742. pad_interval, pad_amount = interval_padding_pairs[0]
  743. num_warps = kernel_metadata[0]
  744. driver = triton.runtime.driver.active
  745. assert isinstance(driver, HIPDriver)
  746. desc = driver.utils.create_tdm_descriptor(elem_bits, block_size, num_warps, pad_interval, pad_amount, shape,
  747. strides, base)
  748. return [desc, *shape, *strides]
  749. def wrap_handle_tensordesc(launcher, signature, tensordesc_metadata):
  750. """
  751. Wrap a kernel launcher function to handle tensor descriptor arguments.
  752. Use the provided `tensordesc_metadata` to determine whether to create
  753. TDMDescriptor objects or decompose the tensor descriptors.
  754. Args:
  755. launcher (callable): The original kernel launcher function.
  756. signature (Dict[int, str]): The kernel signature mapping argument indices to types.
  757. tensordesc_metadata (List[Dict] or None): The list of tensor descriptor metadata, following the order
  758. of tensor descriptor arguments. If None, decompose tensor descriptors.
  759. Returns:
  760. launcher (callable): The wrapped kernel launcher function.
  761. """
  762. has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values())
  763. if not has_tensor_desc_arg:
  764. return launcher
  765. tensordesc_indices = set(
  766. [i for i, sig in enumerate(signature.values()) if isinstance(sig, str) and sig.startswith("tensordesc")])
  767. assert not tensordesc_metadata or len(tensordesc_metadata) == len(tensordesc_indices)
  768. if not tensordesc_metadata:
  769. tensordesc_metadata = [None] * len(tensordesc_indices)
  770. def inner(*args):
  771. meta_args = args[:len(_BASE_ARGS_FORMAT)]
  772. raw_kernel_args = args[len(_BASE_ARGS_FORMAT):]
  773. final_args = []
  774. tensordesc_idx = 0
  775. for i, arg in enumerate(raw_kernel_args):
  776. if i in tensordesc_indices:
  777. tensordesc_args = make_tensordesc_arg(arg, meta_args[7], # kernel_metadata
  778. tensordesc_metadata[tensordesc_idx])
  779. final_args.extend(tensordesc_args)
  780. tensordesc_idx += 1
  781. else:
  782. final_args.append(arg)
  783. return launcher(*meta_args, *final_args)
  784. return inner
  785. class HIPLauncher(object):
  786. def __init__(self, src, metadata):
  787. constants = src.constants if hasattr(src, "constants") else dict()
  788. arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
  789. constants = {arg_idx(idx): value for idx, value in constants.items()}
  790. signature = {idx: value for idx, value in src.signature.items()}
  791. tensordesc_meta = getattr(metadata, "tensordesc_meta", None)
  792. src = make_launcher(constants, signature, metadata.warp_size, tensordesc_meta)
  793. mod = compile_module_from_src(src=src, name="__triton_launcher", include_dirs=include_dirs)
  794. self.launch = wrap_handle_tensordesc(mod.launch, signature, tensordesc_meta)
  795. self.launch_cooperative_grid = metadata.launch_cooperative_grid
  796. self.profile_scratch_size = metadata.profile_scratch_size
  797. self.profile_scratch_align = metadata.profile_scratch_align
  798. def __call__(self, gridX, gridY, gridZ, stream, function, *args):
  799. def allocate_scratch(size, align, allocator):
  800. if size > 0:
  801. grid_size = gridX * gridY * gridZ
  802. alloc_size = grid_size * size
  803. alloc_fn = allocator.get()
  804. return alloc_fn(alloc_size, align, stream)
  805. return None
  806. profile_scratch = allocate_scratch(self.profile_scratch_size, self.profile_scratch_align,
  807. _allocation._profile_allocator)
  808. self.launch(self.launch_cooperative_grid, gridX, gridY, gridZ, stream, function, profile_scratch, *args)
  809. class HIPDriver(GPUDriver):
  810. def __init__(self):
  811. super().__init__()
  812. self.utils = HIPUtils()
  813. self.launcher_cls = HIPLauncher
  814. def get_device_interface(self):
  815. import torch
  816. return torch.cuda
  817. @staticmethod
  818. def is_active():
  819. try:
  820. import torch
  821. return torch.cuda.is_available() and (torch.version.hip is not None)
  822. except ImportError:
  823. return False
  824. def map_python_to_cpp_type(self, ty: str) -> str:
  825. return ty_to_cpp(ty)
  826. def get_current_target(self):
  827. device = self.get_current_device()
  828. device_properties = self.utils.get_device_properties(device)
  829. arch = knobs.runtime.override_arch or device_properties['arch']
  830. warp_size = device_properties['warpSize']
  831. return GPUTarget("hip", arch.split(':')[0], warp_size)
  832. def get_active_torch_device(self):
  833. import torch
  834. # when using hip devices, the device string in pytorch is "cuda"
  835. return torch.device("cuda", self.get_current_device())
  836. def get_benchmarker(self):
  837. from triton.testing import do_bench
  838. return do_bench
  839. def get_empty_cache_for_benchmark(self):
  840. import torch
  841. # It's the same as the Nvidia backend.
  842. cache_size = 256 * 1024 * 1024
  843. return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
  844. def clear_cache(self, cache):
  845. cache.zero_()