# driver.py (26 KB)
  1. import functools
  2. import os
  3. import subprocess
  4. import triton
  5. import re
  6. from pathlib import Path
  7. from triton import knobs
  8. from triton.runtime.build import compile_module_from_src
  9. from triton.runtime import _allocation
  10. from triton.backends.compiler import GPUTarget
  11. from triton.backends.driver import GPUDriver
  12. dirname = os.path.dirname(os.path.realpath(__file__))
  13. include_dirs = [os.path.join(dirname, "include")]
  14. if os.name == "nt":
  15. from triton.windows_utils import find_cuda
  16. _, cuda_inc_dirs, _ = find_cuda()
  17. include_dirs += cuda_inc_dirs
  18. libraries = ['cuda']
  19. else:
  20. libraries = ['libcuda.so.1']
  21. libdevice_dir = os.path.join(dirname, "lib")
  22. PyCUtensorMap = None
  23. @functools.lru_cache()
  24. def libcuda_dirs():
  25. if env_libcuda_path := knobs.nvidia.libcuda_path:
  26. return [env_libcuda_path]
  27. if os.name == "nt":
  28. _, _, cuda_lib_dirs = find_cuda()
  29. return cuda_lib_dirs
  30. libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore")
  31. # each line looks like the following:
  32. # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
  33. locs = [line.split()[-1] for line in libs.splitlines() if "libcuda.so.1" in line]
  34. dirs = [os.path.dirname(loc) for loc in locs]
  35. env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
  36. if env_ld_library_path and not dirs:
  37. dirs = [dir for dir in env_ld_library_path.split(":") if os.path.exists(os.path.join(dir, "libcuda.so.1"))]
  38. msg = 'libcuda.so cannot found!\n'
  39. if locs:
  40. msg += 'Possible files are located at %s.' % str(locs)
  41. msg += 'Please create a symlink of libcuda.so to any of the files.'
  42. else:
  43. msg += 'Please make sure GPU is set up and then run "/sbin/ldconfig"'
  44. msg += ' (requires sudo) to refresh the linker cache.'
  45. assert any(os.path.exists(os.path.join(path, 'libcuda.so.1')) for path in dirs), msg
  46. return dirs
  47. @functools.lru_cache()
  48. def library_dirs():
  49. return [libdevice_dir, *libcuda_dirs()]
  50. # ------------------------
  51. # Utils
  52. # ------------------------
  53. class CudaUtils(object):
  54. def __new__(cls):
  55. if not hasattr(cls, "instance"):
  56. cls.instance = super(CudaUtils, cls).__new__(cls)
  57. return cls.instance
  58. def __init__(self):
  59. mod = compile_module_from_src(
  60. src=Path(os.path.join(dirname, "driver.c")).read_text(),
  61. name="cuda_utils",
  62. library_dirs=library_dirs(),
  63. include_dirs=include_dirs,
  64. libraries=libraries,
  65. )
  66. global PyCUtensorMap
  67. PyCUtensorMap = mod.PyCUtensorMap
  68. self.load_binary = mod.load_binary
  69. self.get_device_properties = mod.get_device_properties
  70. self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters
  71. self.set_printf_fifo_size = mod.set_printf_fifo_size
  72. self.fill_tma_descriptor = mod.fill_tma_descriptor
  73. # ------------------------
  74. # Launcher
  75. # ------------------------
  76. def ty_to_cpp(ty):
  77. if ty[0] == '*':
  78. return "CUdeviceptr"
  79. if ty.startswith("tensordesc"):
  80. return "CUtensorMap"
  81. return {
  82. "i1": "int8_t",
  83. "i8": "int8_t",
  84. "i16": "int16_t",
  85. "i32": "int32_t",
  86. "i64": "int64_t",
  87. "u1": "uint8_t",
  88. "u8": "uint8_t",
  89. "u16": "uint16_t",
  90. "u32": "uint32_t",
  91. "u64": "uint64_t",
  92. "fp16": "double",
  93. "bf16": "double",
  94. "fp32": "double",
  95. "f32": "double",
  96. "fp64": "double",
  97. "nvTmaDesc": "CUtensorMap",
  98. }[ty]
  99. FLOAT_STORAGE_TYPE = {
  100. "fp16": "uint16_t",
  101. "bf16": "uint16_t",
  102. "fp32": "uint32_t",
  103. "f32": "uint32_t",
  104. "fp64": "uint64_t",
  105. }
  106. FLOAT_PACK_FUNCTION = {
  107. "fp16": "pack_fp16",
  108. "bf16": "pack_bf16",
  109. "fp32": "pack_fp32",
  110. "f32": "pack_fp32",
  111. "fp64": "pack_fp64",
  112. }
  113. _BASE_ARGS_FORMAT = "iiiKKppOOOOOO"
  114. _BASE_ARGS_FORMAT_LEN = len(_BASE_ARGS_FORMAT)
  115. def make_launcher(constants, signature, tensordesc_meta):
  116. def _expand_signature(signature):
  117. output = []
  118. tensordesc_idx = 0
  119. # Expand tensor descriptor arguments into either nvTmaDesc, shape and
  120. # strides, or base pointer, shape and strides depending on whether the
  121. # kernel was lowered to use the nvTmaDesc or not.
  122. for sig in signature:
  123. if isinstance(sig, str) and sig.startswith("tensordesc"):
  124. meta = tensordesc_meta[tensordesc_idx] if tensordesc_meta else None
  125. tensordesc_idx += 1
  126. match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", sig)
  127. dtype = match.group(1)
  128. shape = match.group(2)
  129. ndim = shape.count(",") + 1
  130. if meta is None:
  131. output.append("*" + dtype)
  132. # Currently the host side tensor descriptors get passed in as a
  133. # tensor desc, shape, and strides. We have no way to use these
  134. # shape and strides when processing tensor descriptors which is
  135. # why we provide our own decomposition above. Sadly this means
  136. # we have to pass the shape and strides twice.
  137. for _ in range(2 * ndim):
  138. output.append("i64")
  139. output.append("i1")
  140. else:
  141. output.append("nvTmaDesc")
  142. for _ in range(ndim):
  143. output.append("i32")
  144. for _ in range(ndim):
  145. output.append("i64")
  146. else:
  147. output.append(sig)
  148. assert not tensordesc_meta or tensordesc_idx == len(tensordesc_meta)
  149. return output
  150. def _flatten_signature(sig, output):
  151. # Flatten tuples
  152. if isinstance(sig, tuple):
  153. for x in sig:
  154. _flatten_signature(x, output)
  155. else:
  156. output.append(sig)
  157. def _extracted_type(ty):
  158. if isinstance(ty, tuple):
  159. val = ','.join(map(_extracted_type, ty))
  160. return f"[{val}]"
  161. if ty[0] == '*':
  162. return "PyObject*"
  163. if ty in ("constexpr", "nvTmaDesc"):
  164. return "PyObject*"
  165. return ty_to_cpp(ty)
  166. def format_of(ty):
  167. if isinstance(ty, tuple):
  168. val = ''.join(map(format_of, ty))
  169. return f"({val})"
  170. if ty[0] == '*':
  171. return "O"
  172. if ty in ("constexpr", "nvTmaDesc"):
  173. return "O"
  174. if ty.startswith("tensordesc"):
  175. return "O"
  176. return {
  177. "double": "d",
  178. "long": "l",
  179. "int8_t": "b",
  180. "int16_t": "h",
  181. "int32_t": "i",
  182. "int64_t": "L",
  183. "uint8_t": "B",
  184. "uint16_t": "H",
  185. "uint32_t": "I",
  186. "uint64_t": "K",
  187. }[ty_to_cpp(ty)]
  188. expand_signature = _expand_signature(signature.values())
  189. signature = {i: s for i, s in enumerate(expand_signature)}
  190. args_format = ''.join([format_of(ty) for ty in signature.values()])
  191. format = _BASE_ARGS_FORMAT + args_format
  192. flat_signature = []
  193. for sig in signature.values():
  194. _flatten_signature(sig, flat_signature)
  195. signature = {i: s for i, s in enumerate(flat_signature)}
  196. args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
  197. # Record the end of regular arguments;
  198. # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
  199. arg_decl_list = []
  200. for i, ty in signature.items():
  201. if ty == "constexpr":
  202. continue
  203. if ty in FLOAT_STORAGE_TYPE:
  204. arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
  205. else:
  206. arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
  207. arg_decls = ', '.join(arg_decl_list)
  208. internal_args_list = []
  209. for i, ty in signature.items():
  210. if ty[0] == "*":
  211. internal_args_list.append(f"ptr_info{i}.dev_ptr")
  212. elif ty in FLOAT_STORAGE_TYPE:
  213. internal_args_list.append(f"_arg{i}_storage")
  214. elif ty == "nvTmaDesc":
  215. # Note: we have to dereference the pointer
  216. internal_args_list.append(f"*tma_ptr{i}")
  217. elif ty != "constexpr":
  218. internal_args_list.append(f"_arg{i}")
  219. params = range(len(signature))
  220. # generate glue code
  221. newline = '\n '
  222. ptr_decls = [
  223. f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;"
  224. for i, ty in signature.items()
  225. if ty[0] == "*"
  226. ]
  227. tma_decls = [
  228. f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" for i, ty in signature.items()
  229. if ty == "nvTmaDesc"
  230. ]
  231. float_storage_decls = [
  232. f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
  233. for i, ty in signature.items()
  234. if ty in FLOAT_STORAGE_TYPE
  235. ]
  236. params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
  237. params.append("&global_scratch")
  238. params.append("&profile_scratch")
  239. src = f"""
  240. #define _CRT_SECURE_NO_WARNINGS
  241. #include \"cuda.h\"
  242. #ifndef _WIN32
  243. #include <dlfcn.h>
  244. #else
  245. #define WIN32_LEAN_AND_MEAN
  246. #include <windows.h>
  247. #endif
  248. #include <stdbool.h>
  249. #include <stdlib.h>
  250. #define PY_SSIZE_T_CLEAN
  251. #include <Python.h>
  252. typedef struct {{
  253. PyObject_HEAD
  254. _Alignas(128) CUtensorMap tensorMap;
  255. }} PyCUtensorMapObject;
  256. static inline void gpuAssert(CUresult code, const char *file, int line)
  257. {{
  258. if (code != CUDA_SUCCESS)
  259. {{
  260. const char* prefix = "Triton Error [CUDA]: ";
  261. const char* str;
  262. cuGetErrorString(code, &str);
  263. char err[1024] = {{0}};
  264. strcat(err, prefix);
  265. strcat(err, str);
  266. PyGILState_STATE gil_state;
  267. gil_state = PyGILState_Ensure();
  268. PyErr_SetString(PyExc_RuntimeError, err);
  269. PyGILState_Release(gil_state);
  270. }}
  271. }}
  272. #define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
  273. typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra);
  274. #ifndef _WIN32
  275. static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
  276. // Open the shared library
  277. void* handle = dlopen("libcuda.so.1", RTLD_LAZY);
  278. if (!handle) {{
  279. PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1");
  280. return NULL;
  281. }}
  282. // Clear any existing error
  283. dlerror();
  284. cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx");
  285. // Check for errors
  286. const char *dlsym_error = dlerror();
  287. if (dlsym_error) {{
  288. PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from libcuda.so.1");
  289. return NULL;
  290. }}
  291. return cuLaunchKernelExHandle;
  292. }}
  293. #else
  294. static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
  295. // Open the shared library
  296. HMODULE handle = LoadLibraryA("nvcuda.dll");
  297. if (!handle) {{
  298. PyErr_SetString(PyExc_RuntimeError, "Failed to open nvcuda.dll");
  299. return NULL;
  300. }}
  301. cuLaunchKernelEx_t cuLaunchKernelExHandle =
  302. (cuLaunchKernelEx_t)GetProcAddress((HMODULE)handle, "cuLaunchKernelEx");
  303. // Check for errors
  304. long error = GetLastError();
  305. if (error) {{
  306. PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from nvcuda.dll");
  307. return NULL;
  308. }}
  309. return cuLaunchKernelExHandle;
  310. }}
  311. #endif
  312. static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
  313. void *params[] = {{ {', '.join(params)} }};
  314. if (gridX*gridY*gridZ > 0) {{
  315. // 4 attributes that we can currently pass maximum
  316. CUlaunchAttribute launchAttr[4];
  317. static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
  318. if (cuLaunchKernelExHandle == NULL) {{
  319. cuLaunchKernelExHandle = getLaunchKernelExHandle();
  320. }}
  321. CUlaunchConfig config;
  322. config.gridDimX = gridX * num_ctas;
  323. config.gridDimY = gridY;
  324. config.gridDimZ = gridZ;
  325. config.blockDimX = 32 * num_warps;
  326. config.blockDimY = 1;
  327. config.blockDimZ = 1;
  328. config.sharedMemBytes = shared_memory;
  329. config.hStream = stream;
  330. config.attrs = launchAttr;
  331. int num_attrs = 0;
  332. if (launch_pdl != 0) {{
  333. CUlaunchAttribute pdlAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION, .value = 1}};
  334. launchAttr[num_attrs] = pdlAttr;
  335. ++num_attrs;
  336. }}
  337. if (launch_cooperative_grid != 0) {{
  338. CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
  339. launchAttr[num_attrs] = coopAttr;
  340. ++num_attrs;
  341. }}
  342. if (num_ctas != 1) {{
  343. CUlaunchAttribute clusterAttr = {{}};
  344. clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
  345. clusterAttr.value.clusterDim.x = num_ctas;
  346. clusterAttr.value.clusterDim.y = 1;
  347. clusterAttr.value.clusterDim.z = 1;
  348. launchAttr[num_attrs] = clusterAttr;
  349. ++num_attrs;
  350. CUlaunchAttribute clusterSchedulingAttr = {{}};
  351. clusterSchedulingAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
  352. clusterSchedulingAttr.value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
  353. launchAttr[num_attrs] = clusterSchedulingAttr;
  354. ++num_attrs;
  355. }}
  356. // num_ctas == 16 is non-portable. Does work for H100 and B200 tho
  357. config.numAttrs = num_attrs;
  358. if (num_ctas == 16) {{
  359. CUDA_CHECK(cuFuncSetAttribute(
  360. function,
  361. CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED,
  362. 1
  363. ));
  364. }}
  365. CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
  366. }}
  367. }}
  368. typedef struct _DevicePtrInfo {{
  369. CUdeviceptr dev_ptr;
  370. bool valid;
  371. }} DevicePtrInfo;
  372. static PyObject* data_ptr_str = NULL;
  373. static PyObject* py_tensor_map_type = NULL;
  374. static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
  375. DevicePtrInfo ptr_info;
  376. ptr_info.dev_ptr = 0;
  377. ptr_info.valid = true;
  378. if (PyLong_Check(obj)) {{
  379. ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj);
  380. return ptr_info;
  381. }}
  382. if (obj == Py_None) {{
  383. // valid nullptr
  384. return ptr_info;
  385. }}
  386. PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
  387. if (!ret) {{
  388. PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
  389. ptr_info.valid = false;
  390. goto cleanup;
  391. }}
  392. if (!PyLong_Check(ret)) {{
  393. PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
  394. ptr_info.valid = false;
  395. goto cleanup;
  396. }}
  397. ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
  398. if(!ptr_info.dev_ptr)
  399. return ptr_info;
  400. uint64_t dev_ptr;
  401. int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
  402. if (status == CUDA_ERROR_INVALID_VALUE) {{
  403. PyErr_Format(PyExc_ValueError,
  404. "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
  405. ptr_info.valid = false;
  406. }} else if (status != CUDA_SUCCESS) {{
  407. CUDA_CHECK(status); // Catch any other cuda API errors
  408. ptr_info.valid = false;
  409. }}
  410. ptr_info.dev_ptr = dev_ptr;
  411. cleanup:
  412. Py_XDECREF(ret);
  413. return ptr_info;
  414. }}
  415. static inline CUtensorMap* getTmaDesc(PyObject *obj) {{
  416. if (sizeof(CUtensorMap*) != 8) {{
  417. PyErr_SetString(PyExc_SystemError, "getTmaDesc() requires 64-bit compilation");
  418. return NULL;
  419. }}
  420. if (Py_TYPE(obj) != (PyTypeObject*)py_tensor_map_type) {{
  421. PyErr_Format(PyExc_TypeError, "object must be of type PyCUtensorMap, got %s", Py_TYPE(obj)->tp_name);
  422. return NULL;
  423. }}
  424. CUtensorMap* map = &((PyCUtensorMapObject*)obj)->tensorMap;
  425. uintptr_t align_128 = (uintptr_t)map & (128 - 1);
  426. if (align_128 != 0) {{
  427. PyErr_Format(PyExc_ValueError, "CUtensorMap must be aligned to 128B, but got (&map) mod 128 = %ld", align_128);
  428. return NULL;
  429. }}
  430. return map;
  431. }}
  432. static void ensureCudaContext() {{
  433. CUcontext pctx;
  434. CUDA_CHECK(cuCtxGetCurrent(&pctx));
  435. if (!pctx) {{
  436. // Ensure device context.
  437. CUdevice device;
  438. CUDA_CHECK(cuDeviceGet(&device, 0));
  439. CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
  440. CUDA_CHECK(cuCtxSetCurrent(pctx));
  441. }}
  442. }}
  443. static uint16_t pack_fp16(double f) {{
  444. uint16_t result;
  445. // from https://github.com/python/pythoncapi-compat
  446. #if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
  447. _PyFloat_Pack2(f, (unsigned char*)&result, 1);
  448. #else
  449. PyFloat_Pack2(f, (unsigned char*)&result, 1);
  450. #endif
  451. return result;
  452. }}
  453. static uint16_t pack_bf16(double f) {{
  454. float f32 = (float)f;
  455. uint32_t u32 = *(uint32_t*)&f32;
  456. return (uint16_t)(u32 >> 16);
  457. }}
  458. static uint32_t pack_fp32(double f) {{
  459. float f32 = (float)f;
  460. return *(uint32_t*)&f32;
  461. }}
  462. static uint64_t pack_fp64(double f) {{
  463. return *(uint64_t*)&f;
  464. }}
  465. static PyObject* launch(PyObject* self, PyObject* args) {{
  466. // ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes
  467. ensureCudaContext();
  468. int gridX, gridY, gridZ;
  469. uint64_t _stream;
  470. uint64_t _function;
  471. int launch_cooperative_grid;
  472. int launch_pdl;
  473. PyObject *launch_enter_hook = NULL;
  474. PyObject *launch_exit_hook = NULL;
  475. PyObject *kernel_metadata = NULL;
  476. PyObject *launch_metadata = NULL;
  477. PyObject *global_scratch_obj = NULL;
  478. PyObject *profile_scratch_obj = NULL;
  479. {newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])}
  480. if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ,
  481. &_stream, &_function, &launch_cooperative_grid, &launch_pdl, &global_scratch_obj, &profile_scratch_obj,
  482. &kernel_metadata, &launch_metadata,
  483. &launch_enter_hook, &launch_exit_hook{args_list})) {{
  484. return NULL;
  485. }}
  486. int num_warps, num_ctas, shared_memory;
  487. if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
  488. PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple");
  489. return NULL;
  490. }}
  491. // extract launch metadata
  492. if (launch_enter_hook != Py_None){{
  493. PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
  494. if (!ret)
  495. return NULL;
  496. Py_DECREF(ret);
  497. }}
  498. CUdeviceptr global_scratch = 0;
  499. if (global_scratch_obj != Py_None) {{
  500. DevicePtrInfo global_scratch_info = getPointer(global_scratch_obj, -1);
  501. if (!global_scratch_info.valid) {{
  502. return NULL;
  503. }}
  504. global_scratch = global_scratch_info.dev_ptr;
  505. }}
  506. CUdeviceptr profile_scratch = 0;
  507. if (profile_scratch_obj != Py_None) {{
  508. DevicePtrInfo profile_scratch_info = getPointer(profile_scratch_obj, -1);
  509. if (!profile_scratch_info.valid) {{
  510. return NULL;
  511. }}
  512. profile_scratch = profile_scratch_info.dev_ptr;
  513. }}
  514. // raise exception asap
  515. {newline.join(ptr_decls)}
  516. {newline.join(tma_decls)}
  517. {newline.join(float_storage_decls)}
  518. Py_BEGIN_ALLOW_THREADS;
  519. _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch, profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
  520. Py_END_ALLOW_THREADS;
  521. if (PyErr_Occurred()) {{
  522. return NULL;
  523. }}
  524. if(launch_exit_hook != Py_None){{
  525. PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
  526. if (!ret)
  527. return NULL;
  528. Py_DECREF(ret);
  529. }}
  530. Py_RETURN_NONE;
  531. }}
  532. static PyMethodDef ModuleMethods[] = {{
  533. {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
  534. {{NULL, NULL, 0, NULL}} // sentinel
  535. }};
  536. static struct PyModuleDef ModuleDef = {{
  537. PyModuleDef_HEAD_INIT,
  538. \"__triton_launcher\",
  539. NULL, //documentation
  540. -1, //size
  541. ModuleMethods
  542. }};
  543. PyMODINIT_FUNC PyInit___triton_launcher(void) {{
  544. data_ptr_str = PyUnicode_InternFromString("data_ptr");
  545. if(data_ptr_str == NULL) {{
  546. return NULL;
  547. }}
  548. PyObject* driver_mod = PyImport_ImportModule("triton.backends.nvidia.driver");
  549. if (driver_mod == NULL) {{
  550. return NULL;
  551. }}
  552. py_tensor_map_type = PyObject_GetAttrString(driver_mod, "PyCUtensorMap");
  553. if (py_tensor_map_type == NULL) {{
  554. return NULL;
  555. }}
  556. PyObject *m = PyModule_Create(&ModuleDef);
  557. if(m == NULL) {{
  558. return NULL;
  559. }}
  560. PyModule_AddFunctions(m, ModuleMethods);
  561. return m;
  562. }}
  563. """
  564. return src
  565. # The TMA dtype enum values are slightly different on host vs device...
  566. TMA_DTYPE_DEVICE_TO_HOST = dict((i, i) for i in range(16))
  567. TMA_DTYPE_DEVICE_TO_HOST[8] = 10
  568. TMA_DTYPE_DEVICE_TO_HOST[9] = 8
  569. TMA_DTYPE_DEVICE_TO_HOST[10] = 9
  570. def make_tensordesc_arg(arg, metadata):
  571. if metadata is None:
  572. # Currently the host side tensor descriptors get decomposed in
  573. # the frontend to tensor desc, shape, and strides. We have no
  574. # way to use these shape and strides when processing tensor
  575. # descriptors which is why we provide our own decomposition
  576. # above. Sadly this means we have to pass the shape and strides
  577. # twice.
  578. return [arg.base, *arg.shape, *arg.strides, arg.padding == "nan", *arg.shape, *arg.strides]
  579. swizzle = metadata["swizzle"]
  580. elem_size = metadata["elem_size"]
  581. elem_type = metadata["elem_type"]
  582. block_size = metadata["block_size"]
  583. fp4_padded = metadata["fp4_padded"]
  584. shape = arg.shape
  585. strides = arg.strides
  586. assert strides[-1] == 1
  587. padding = 1 if arg.padding == "nan" else 0
  588. if fp4_padded:
  589. shape = list(shape)
  590. shape[-1] *= 2
  591. cu_tensor_map = triton.runtime.driver.active.utils.fill_tma_descriptor(
  592. arg.base.data_ptr(),
  593. swizzle,
  594. elem_size,
  595. TMA_DTYPE_DEVICE_TO_HOST[elem_type],
  596. block_size,
  597. shape,
  598. strides,
  599. padding,
  600. )
  601. return [cu_tensor_map, *shape, *strides]
  602. def wrap_handle_tensordesc(launcher, signature, tensordesc_meta):
  603. has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values())
  604. if not has_tensor_desc_arg:
  605. return launcher
  606. tensordesc_indices = set(
  607. [i for i, sig in enumerate(signature.values()) if isinstance(sig, str) and sig.startswith("tensordesc")])
  608. assert not tensordesc_meta or len(tensordesc_meta) == len(tensordesc_indices)
  609. if not tensordesc_meta:
  610. tensordesc_meta = [None] * len(tensordesc_indices)
  611. def inner(*args):
  612. final_args = list(args[:_BASE_ARGS_FORMAT_LEN])
  613. tensordesc_idx = 0
  614. for i, arg in enumerate(args[_BASE_ARGS_FORMAT_LEN:]):
  615. if i in tensordesc_indices:
  616. final_args.extend(make_tensordesc_arg(arg, tensordesc_meta[tensordesc_idx]))
  617. tensordesc_idx += 1
  618. else:
  619. final_args.append(arg)
  620. return launcher(*final_args)
  621. return inner
  622. class CudaLauncher(object):
  623. def __init__(self, src, metadata):
  624. constants = src.constants if hasattr(src, "constants") else dict()
  625. arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
  626. constants = {arg_idx(idx): value for idx, value in constants.items()}
  627. signature = {idx: value for idx, value in src.signature.items()}
  628. tensordesc_meta = getattr(metadata, "tensordesc_meta", None)
  629. src = make_launcher(constants, signature, tensordesc_meta)
  630. mod = compile_module_from_src(
  631. src=src,
  632. name="__triton_launcher",
  633. library_dirs=library_dirs(),
  634. include_dirs=include_dirs,
  635. libraries=libraries,
  636. )
  637. self.num_ctas = getattr(metadata, "num_ctas", 1)
  638. self.launch = wrap_handle_tensordesc(mod.launch, signature, tensordesc_meta)
  639. self.global_scratch_size = metadata.global_scratch_size
  640. self.global_scratch_align = metadata.global_scratch_align
  641. self.profile_scratch_size = metadata.profile_scratch_size
  642. self.profile_scratch_align = metadata.profile_scratch_align
  643. self.launch_cooperative_grid = metadata.launch_cooperative_grid
  644. self.launch_pdl = metadata.launch_pdl
  645. def __call__(self, gridX, gridY, gridZ, stream, function, *args):
  646. def allocate_scratch(size, align, allocator):
  647. if size > 0:
  648. grid_size = gridX * gridY * gridZ
  649. alloc_size = grid_size * self.num_ctas * size
  650. alloc_fn = allocator.get()
  651. return alloc_fn(alloc_size, align, stream)
  652. return None
  653. global_scratch = allocate_scratch(self.global_scratch_size, self.global_scratch_align, _allocation._allocator)
  654. profile_scratch = allocate_scratch(self.profile_scratch_size, self.profile_scratch_align,
  655. _allocation._profile_allocator)
  656. self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
  657. global_scratch, profile_scratch, *args)
  658. class CudaDriver(GPUDriver):
  659. def __init__(self):
  660. self.utils = CudaUtils() # TODO: make static
  661. self.launcher_cls = CudaLauncher
  662. super().__init__()
  663. def get_current_target(self):
  664. device = self.get_current_device()
  665. capability = self.get_device_capability(device)
  666. capability = capability[0] * 10 + capability[1]
  667. warp_size = 32
  668. return GPUTarget("cuda", capability, warp_size)
  669. def get_active_torch_device(self):
  670. import torch
  671. return torch.device("cuda", self.get_current_device())
  672. def get_device_interface(self):
  673. import torch
  674. return torch.cuda
  675. @staticmethod
  676. def is_active():
  677. try:
  678. import torch
  679. return torch.cuda.is_available() and (torch.version.hip is None)
  680. except ImportError:
  681. return False
  682. def map_python_to_cpp_type(self, ty: str) -> str:
  683. return ty_to_cpp(ty)
  684. def get_benchmarker(self):
  685. from triton.testing import do_bench
  686. return do_bench
  687. def get_empty_cache_for_benchmark(self):
  688. import torch
  689. # We maintain a buffer of 256 MB that we clear
  690. # before each kernel call to make sure that the L2 cache
  691. # doesn't contain any input data before the run
  692. cache_size = 256 * 1024 * 1024
  693. return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
  694. def clear_cache(self, cache):
  695. cache.zero_()