| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803 |
- import functools
- import os
- import subprocess
- import triton
- import re
- from pathlib import Path
- from triton import knobs
- from triton.runtime.build import compile_module_from_src
- from triton.runtime import _allocation
- from triton.backends.compiler import GPUTarget
- from triton.backends.driver import GPUDriver
# Directory holding this backend module; bundled headers and libraries live beneath it.
dirname = os.path.dirname(os.path.realpath(__file__))
# Include paths passed to the C compiler when building the native helper/launcher modules.
include_dirs = [os.path.join(dirname, "include")]
if os.name != "nt":
    # POSIX: link directly against the versioned driver library.
    libraries = ['libcuda.so.1']
else:
    # Windows: the CUDA toolkit location is discovered at import time.
    from triton.windows_utils import find_cuda
    _, cuda_inc_dirs, _ = find_cuda()
    include_dirs.extend(cuda_inc_dirs)
    libraries = ['cuda']
# Bundled library directory; placed first on the linker search path by library_dirs().
libdevice_dir = os.path.join(dirname, "lib")
# Replaced with the native tensor-map type once CudaUtils() has compiled driver.c.
PyCUtensorMap = None
@functools.lru_cache()
def libcuda_dirs():
    """Return the directories that contain the CUDA driver library.

    Resolution order:

    1. ``knobs.nvidia.libcuda_path`` if set (returned as-is, not validated).
    2. On Windows, the library directories reported by ``find_cuda()``.
    3. Entries of the dynamic-linker cache (``/sbin/ldconfig -p``) for
       ``libcuda.so.1``.
    4. ``LD_LIBRARY_PATH`` entries containing ``libcuda.so.1``, consulted only
       when the linker cache produced no candidates (including when
       ``ldconfig`` is missing or fails).

    The result is cached for the lifetime of the process.

    Raises:
        AssertionError: if no discovered directory actually contains
            ``libcuda.so.1``. Raised explicitly (not via ``assert``) so the
            check still runs under ``python -O``.
    """
    if env_libcuda_path := knobs.nvidia.libcuda_path:
        return [env_libcuda_path]

    if os.name == "nt":
        _, _, cuda_lib_dirs = find_cuda()
        return cuda_lib_dirs

    try:
        libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore")
    except (OSError, subprocess.CalledProcessError):
        # ldconfig can be absent (minimal containers, non-FHS distros) or fail;
        # fall through to the LD_LIBRARY_PATH search instead of crashing here.
        libs = ""
    # each line looks like the following:
    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
    locs = [line.split()[-1] for line in libs.splitlines() if "libcuda.so.1" in line]
    dirs = [os.path.dirname(loc) for loc in locs]
    env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
    if env_ld_library_path and not dirs:
        dirs = [
            lib_dir for lib_dir in env_ld_library_path.split(":")
            if os.path.exists(os.path.join(lib_dir, "libcuda.so.1"))
        ]

    if not any(os.path.exists(os.path.join(path, 'libcuda.so.1')) for path in dirs):
        msg = 'libcuda.so cannot be found!\n'
        if locs:
            msg += 'Possible files are located at %s. ' % str(locs)
            msg += 'Please create a symlink of libcuda.so to any of the files.'
        else:
            msg += 'Please make sure GPU is set up and then run "/sbin/ldconfig"'
            msg += ' (requires sudo) to refresh the linker cache.'
        raise AssertionError(msg)
    return dirs
@functools.lru_cache()
def library_dirs():
    """Return the linker search path for native modules.

    The bundled ``lib`` directory comes first, followed by every directory
    reported by ``libcuda_dirs()``. The list is computed once per process.
    """
    search_path = [libdevice_dir]
    search_path.extend(libcuda_dirs())
    return search_path
- # ------------------------
- # Utils
- # ------------------------
class CudaUtils(object):
    """Process-wide wrapper around the compiled ``cuda_utils`` extension.

    ``driver.c`` (next to this file) is compiled with ``compile_module_from_src``.
    Its entry points are then exposed as attributes of this object.

    The class is a singleton: ``__new__`` always returns the same instance.
    ``__init__`` builds and loads the native module only on first use, so
    repeated ``CudaUtils()`` calls do not redo that work.
    """

    def __new__(cls):
        if not hasattr(cls, "instance"):
            cls.instance = super(CudaUtils, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        # Python runs __init__ on every CudaUtils() call even though __new__ hands
        # back the cached instance; skip the rebuild once initialisation succeeded.
        if getattr(self, "_initialized", False):
            return
        mod = compile_module_from_src(
            src=Path(os.path.join(dirname, "driver.c")).read_text(),
            name="cuda_utils",
            library_dirs=library_dirs(),
            include_dirs=include_dirs,
            libraries=libraries,
        )
        # Publish the native tensor-map type at module level; the generated
        # launcher source looks up "PyCUtensorMap" on this module at import time.
        global PyCUtensorMap
        PyCUtensorMap = mod.PyCUtensorMap
        self.load_binary = mod.load_binary
        self.get_device_properties = mod.get_device_properties
        self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters
        self.set_printf_fifo_size = mod.set_printf_fifo_size
        self.fill_tma_descriptor = mod.fill_tma_descriptor
        # Only mark as initialised after everything succeeded, so a failed build
        # is retried on the next construction.
        self._initialized = True
- # ------------------------
- # Launcher
- # ------------------------
# Triton scalar type name -> C type used for that scalar in generated launcher
# code. Every floating-point kind is received as a C ``double``; make_launcher
# re-packs it to the kernel's storage width before launching.
_SCALAR_CPP_TYPES = {
    "i1": "int8_t",
    "i8": "int8_t",
    "i16": "int16_t",
    "i32": "int32_t",
    "i64": "int64_t",
    "u1": "uint8_t",
    "u8": "uint8_t",
    "u16": "uint16_t",
    "u32": "uint32_t",
    "u64": "uint64_t",
    "fp16": "double",
    "bf16": "double",
    "fp32": "double",
    "f32": "double",
    "fp64": "double",
    "nvTmaDesc": "CUtensorMap",
}


def ty_to_cpp(ty):
    """Map a Triton type string to the C type used in generated launcher code.

    Pointer types (leading ``*``) map to ``CUdeviceptr``. Tensor descriptors
    (``tensordesc...``) map to ``CUtensorMap``. All other names are looked up in
    ``_SCALAR_CPP_TYPES``.

    Raises:
        KeyError: for a type string with no known C mapping.
    """
    is_pointer = ty[0] == '*'
    if is_pointer:
        return "CUdeviceptr"
    return "CUtensorMap" if ty.startswith("tensordesc") else _SCALAR_CPP_TYPES[ty]
# Unsigned C integer type wide enough to hold the bit pattern of each
# floating-point scalar kind once it has been packed for the kernel.
FLOAT_STORAGE_TYPE = dict(
    fp16="uint16_t",
    bf16="uint16_t",
    fp32="uint32_t",
    f32="uint32_t",
    fp64="uint64_t",
)
# Name of the C helper (emitted in the launcher source) that converts the
# parsed ``double`` into the storage bit pattern above.
FLOAT_PACK_FUNCTION = dict(
    fp16="pack_fp16",
    bf16="pack_bf16",
    fp32="pack_fp32",
    f32="pack_fp32",
    fp64="pack_fp64",
)
# PyArg_ParseTuple codes for the fixed, signature-independent launcher
# arguments, in order:
#   iii     gridX, gridY, gridZ
#   KK      stream, function
#   pp      launch_cooperative_grid, launch_pdl
#   OOOOOO  global_scratch, profile_scratch, kernel_metadata,
#           launch_metadata, launch_enter_hook, launch_exit_hook
_BASE_ARGS_FORMAT = "iii" + "KK" + "pp" + "O" * 6
_BASE_ARGS_FORMAT_LEN = len(_BASE_ARGS_FORMAT)
def make_launcher(constants, signature, tensordesc_meta):
    """Generate the C source of a ``__triton_launcher`` extension module.

    The generated module exposes one ``launch`` function specialised for
    ``signature``. ``launch`` does the following:

    - Parses the Python arguments with ``PyArg_ParseTuple``.
    - Resolves pointer, tensor-map and floating-point arguments into the form
      the kernel expects.
    - Launches the kernel through ``cuLaunchKernelEx``.

    Args:
        constants: Not used by the generator; kept for interface compatibility.
        signature: Mapping of argument index -> Triton type string (for example
            ``"*fp32"``, ``"i32"``, ``"constexpr"``, ``"tensordesc<...>"``) or a
            (possibly nested) tuple of such strings.
        tensordesc_meta: One entry per top-level ``tensordesc`` argument, in
            order. An entry is ``None`` when that descriptor is passed as base
            pointer + shape/strides, otherwise its metadata (the descriptor is
            then passed as an ``nvTmaDesc``). A falsy value means every
            descriptor is passed decomposed.

    Returns:
        The C source code as a string.
    """

    def _expand_signature(signature):
        # Expand tensor descriptor arguments into either nvTmaDesc, shape and
        # strides, or base pointer, shape and strides depending on whether the
        # kernel was lowered to use the nvTmaDesc or not.
        output = []
        tensordesc_idx = 0
        for sig in signature:
            if isinstance(sig, str) and sig.startswith("tensordesc"):
                meta = tensordesc_meta[tensordesc_idx] if tensordesc_meta else None
                tensordesc_idx += 1

                match = re.match(r"tensordesc<([^[>]*)\[([^]]*)\]", sig)
                dtype = match.group(1)
                shape = match.group(2)
                ndim = shape.count(",") + 1

                if meta is None:
                    output.append("*" + dtype)
                    # Currently the host side tensor descriptors get passed in as a
                    # tensor desc, shape, and strides. We have no way to use these
                    # shape and strides when processing tensor descriptors which is
                    # why we provide our own decomposition above. Sadly this means
                    # we have to pass the shape and strides twice.
                    for _ in range(2 * ndim):
                        output.append("i64")
                    output.append("i1")
                else:
                    output.append("nvTmaDesc")

                # Both variants are followed by shape (i32) and strides (i64);
                # this must match the list built by make_tensordesc_arg.
                for _ in range(ndim):
                    output.append("i32")
                for _ in range(ndim):
                    output.append("i64")
            else:
                output.append(sig)

        assert not tensordesc_meta or tensordesc_idx == len(tensordesc_meta)
        return output

    def _flatten_signature(sig, output):
        # Flatten tuples
        if isinstance(sig, tuple):
            for x in sig:
                _flatten_signature(x, output)
        else:
            output.append(sig)

    def _extracted_type(ty):
        # C type of the local that PyArg_ParseTuple writes into.
        if isinstance(ty, tuple):
            val = ','.join(map(_extracted_type, ty))
            return f"[{val}]"
        if ty[0] == '*':
            return "PyObject*"
        if ty in ("constexpr", "nvTmaDesc"):
            return "PyObject*"
        return ty_to_cpp(ty)

    def format_of(ty):
        # PyArg_ParseTuple code for one argument; tuples become "(...)" groups.
        if isinstance(ty, tuple):
            val = ''.join(map(format_of, ty))
            return f"({val})"
        if ty[0] == '*':
            return "O"
        if ty in ("constexpr", "nvTmaDesc"):
            return "O"
        if ty.startswith("tensordesc"):
            return "O"
        return {
            "double": "d",
            "long": "l",
            "int8_t": "b",
            "int16_t": "h",
            "int32_t": "i",
            "int64_t": "L",
            "uint8_t": "B",
            "uint16_t": "H",
            "uint32_t": "I",
            "uint64_t": "K",
        }[ty_to_cpp(ty)]

    expand_signature = _expand_signature(signature.values())
    signature = {i: s for i, s in enumerate(expand_signature)}

    # The parse format is built from the *unflattened* signature so that tuple
    # arguments are unpacked by PyArg_ParseTuple's "(...)" groups.
    args_format = ''.join([format_of(ty) for ty in signature.values()])
    parse_format = _BASE_ARGS_FORMAT + args_format

    flat_signature = []
    for sig in signature.values():
        _flatten_signature(sig, flat_signature)
    signature = {i: s for i, s in enumerate(flat_signature)}
    args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
    # Record the end of regular arguments;
    # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
    arg_decl_list = []
    for i, ty in signature.items():
        if ty == "constexpr":
            continue
        if ty in FLOAT_STORAGE_TYPE:
            arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
        else:
            arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
    arg_decls = ', '.join(arg_decl_list)
    internal_args_list = []
    for i, ty in signature.items():
        if ty[0] == "*":
            internal_args_list.append(f"ptr_info{i}.dev_ptr")
        elif ty in FLOAT_STORAGE_TYPE:
            internal_args_list.append(f"_arg{i}_storage")
        elif ty == "nvTmaDesc":
            # Note: we have to dereference the pointer
            internal_args_list.append(f"*tma_ptr{i}")
        elif ty != "constexpr":
            internal_args_list.append(f"_arg{i}")

    # generate glue code
    newline = '\n  '
    ptr_decls = [
        f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;"
        for i, ty in signature.items()
        if ty[0] == "*"
    ]
    tma_decls = [
        f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" for i, ty in signature.items()
        if ty == "nvTmaDesc"
    ]
    float_storage_decls = [
        f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
        for i, ty in signature.items()
        if ty in FLOAT_STORAGE_TYPE
    ]
    # Kernel parameter array: every non-constexpr argument, then the two scratch pointers.
    params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
    params.append("&global_scratch")
    params.append("&profile_scratch")
    src = f"""
#define _CRT_SECURE_NO_WARNINGS
#include \"cuda.h\"

#ifndef _WIN32
#include <dlfcn.h>
#else
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#endif

#include <stdbool.h>
#include <stdlib.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>

typedef struct {{
  PyObject_HEAD
  _Alignas(128) CUtensorMap tensorMap;
}} PyCUtensorMapObject;

static inline void gpuAssert(CUresult code, const char *file, int line)
{{
  if (code != CUDA_SUCCESS)
  {{
    const char* prefix = "Triton Error [CUDA]: ";
    const char* str = NULL;
    // cuGetErrorString fails and leaves str NULL for unrecognized codes;
    // never hand a NULL pointer to strcat.
    if (cuGetErrorString(code, &str) != CUDA_SUCCESS || str == NULL) {{
      str = "unrecognized error code";
    }}
    char err[1024] = {{0}};
    strcat(err, prefix);
    strcat(err, str);
    PyGILState_STATE gil_state;
    gil_state = PyGILState_Ensure();
    PyErr_SetString(PyExc_RuntimeError, err);
    PyGILState_Release(gil_state);
  }}
}}

#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}

// getLaunchKernelExHandle runs inside _launch, i.e. with the GIL released:
// take the GIL before touching the Python error state.
static void setRuntimeErrorWithGIL(const char* msg) {{
  PyGILState_STATE gil_state = PyGILState_Ensure();
  PyErr_SetString(PyExc_RuntimeError, msg);
  PyGILState_Release(gil_state);
}}

typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra);

#ifndef _WIN32
static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
  // Open the shared library
  void* handle = dlopen("libcuda.so.1", RTLD_LAZY);
  if (!handle) {{
    setRuntimeErrorWithGIL("Failed to open libcuda.so.1");
    return NULL;
  }}
  // Clear any existing error
  dlerror();
  cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx");
  // Check for errors
  const char *dlsym_error = dlerror();
  if (dlsym_error) {{
    setRuntimeErrorWithGIL("Failed to retrieve cuLaunchKernelEx from libcuda.so.1");
    return NULL;
  }}
  return cuLaunchKernelExHandle;
}}
#else
static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
  // Open the shared library
  HMODULE handle = LoadLibraryA("nvcuda.dll");
  if (!handle) {{
    setRuntimeErrorWithGIL("Failed to open nvcuda.dll");
    return NULL;
  }}
  cuLaunchKernelEx_t cuLaunchKernelExHandle =
      (cuLaunchKernelEx_t)GetProcAddress((HMODULE)handle, "cuLaunchKernelEx");
  // GetProcAddress reports failure by returning NULL; GetLastError() is only
  // meaningful after a failure and may hold a stale code on success.
  if (!cuLaunchKernelExHandle) {{
    setRuntimeErrorWithGIL("Failed to retrieve cuLaunchKernelEx from nvcuda.dll");
    return NULL;
  }}
  return cuLaunchKernelExHandle;
}}
#endif

static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
  void *params[] = {{ {', '.join(params)} }};
  if (gridX*gridY*gridZ > 0) {{
    // 4 attributes that we can currently pass maximum
    CUlaunchAttribute launchAttr[4];
    static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
    if (cuLaunchKernelExHandle == NULL) {{
      cuLaunchKernelExHandle = getLaunchKernelExHandle();
      if (cuLaunchKernelExHandle == NULL) {{
        // The Python error is already set; never call through a NULL pointer.
        return;
      }}
    }}
    CUlaunchConfig config;
    config.gridDimX = gridX * num_ctas;
    config.gridDimY = gridY;
    config.gridDimZ = gridZ;
    config.blockDimX = 32 * num_warps;
    config.blockDimY = 1;
    config.blockDimZ = 1;
    config.sharedMemBytes = shared_memory;
    config.hStream = stream;
    config.attrs = launchAttr;
    int num_attrs = 0;

    if (launch_pdl != 0) {{
      CUlaunchAttribute pdlAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION, .value = 1}};
      launchAttr[num_attrs] = pdlAttr;
      ++num_attrs;
    }}

    if (launch_cooperative_grid != 0) {{
      CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
      launchAttr[num_attrs] = coopAttr;
      ++num_attrs;
    }}

    if (num_ctas != 1) {{
      CUlaunchAttribute clusterAttr = {{}};
      clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
      clusterAttr.value.clusterDim.x = num_ctas;
      clusterAttr.value.clusterDim.y = 1;
      clusterAttr.value.clusterDim.z = 1;
      launchAttr[num_attrs] = clusterAttr;
      ++num_attrs;

      CUlaunchAttribute clusterSchedulingAttr = {{}};
      clusterSchedulingAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
      clusterSchedulingAttr.value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
      launchAttr[num_attrs] = clusterSchedulingAttr;
      ++num_attrs;
    }}

    // num_ctas == 16 is non-portable. Does work for H100 and B200 tho
    config.numAttrs = num_attrs;
    if (num_ctas == 16) {{
      CUDA_CHECK(cuFuncSetAttribute(
        function,
        CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED,
        1
      ));
    }}

    CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
  }}
}}

typedef struct _DevicePtrInfo {{
  CUdeviceptr dev_ptr;
  bool valid;
}} DevicePtrInfo;

static PyObject* data_ptr_str = NULL;
static PyObject* py_tensor_map_type = NULL;

static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
  DevicePtrInfo ptr_info;
  ptr_info.dev_ptr = 0;
  ptr_info.valid = true;
  if (PyLong_Check(obj)) {{
    ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj);
    // Negative or oversized integers raise OverflowError; do not launch with garbage.
    if (PyErr_Occurred()) {{
      ptr_info.valid = false;
    }}
    return ptr_info;
  }}
  if (obj == Py_None) {{
    // valid nullptr
    return ptr_info;
  }}
  PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
  if (!ret) {{
    PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
    ptr_info.valid = false;
    goto cleanup;
  }}
  if (!PyLong_Check(ret)) {{
    PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
    ptr_info.valid = false;
    goto cleanup;
  }}
  ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
  if (PyErr_Occurred()) {{
    ptr_info.valid = false;
    goto cleanup;
  }}
  if (!ptr_info.dev_ptr) {{
    // A null data pointer is passed through unchanged; still release `ret`.
    goto cleanup;
  }}
  uint64_t dev_ptr;
  int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
  if (status == CUDA_ERROR_INVALID_VALUE) {{
    PyErr_Format(PyExc_ValueError,
                 "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
    ptr_info.valid = false;
  }} else if (status != CUDA_SUCCESS) {{
    CUDA_CHECK(status);  // Catch any other cuda API errors
    ptr_info.valid = false;
  }}
  ptr_info.dev_ptr = dev_ptr;
cleanup:
  Py_XDECREF(ret);
  return ptr_info;
}}

static inline CUtensorMap* getTmaDesc(PyObject *obj) {{
  if (sizeof(CUtensorMap*) != 8) {{
    PyErr_SetString(PyExc_SystemError, "getTmaDesc() requires 64-bit compilation");
    return NULL;
  }}

  if (Py_TYPE(obj) != (PyTypeObject*)py_tensor_map_type) {{
    PyErr_Format(PyExc_TypeError, "object must be of type PyCUtensorMap, got %s", Py_TYPE(obj)->tp_name);
    return NULL;
  }}

  CUtensorMap* map = &((PyCUtensorMapObject*)obj)->tensorMap;
  uintptr_t align_128 = (uintptr_t)map & (128 - 1);
  if (align_128 != 0) {{
    // %ld expects a long; uintptr_t is wider than long on 64-bit Windows.
    PyErr_Format(PyExc_ValueError, "CUtensorMap must be aligned to 128B, but got (&map) mod 128 = %ld", (long)align_128);
    return NULL;
  }}
  return map;
}}

static void ensureCudaContext() {{
  CUcontext pctx;
  CUDA_CHECK(cuCtxGetCurrent(&pctx));
  if (!pctx) {{
    // Ensure device context.
    CUdevice device;
    CUDA_CHECK(cuDeviceGet(&device, 0));
    CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
    CUDA_CHECK(cuCtxSetCurrent(pctx));
  }}
}}

static uint16_t pack_fp16(double f) {{
  uint16_t result;
  // from https://github.com/python/pythoncapi-compat
#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
  _PyFloat_Pack2(f, (unsigned char*)&result, 1);
#else
  PyFloat_Pack2(f, (unsigned char*)&result, 1);
#endif
  return result;
}}

static uint16_t pack_bf16(double f) {{
  float f32 = (float)f;
  uint32_t u32 = *(uint32_t*)&f32;
  return (uint16_t)(u32 >> 16);
}}

static uint32_t pack_fp32(double f) {{
  float f32 = (float)f;
  return *(uint32_t*)&f32;
}}

static uint64_t pack_fp64(double f) {{
  return *(uint64_t*)&f;
}}

static PyObject* launch(PyObject* self, PyObject* args) {{
  // ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes
  ensureCudaContext();
  // ensureCudaContext reports failures via CUDA_CHECK (a Python error);
  // bail out instead of launching without a usable context.
  if (PyErr_Occurred()) {{
    return NULL;
  }}

  int gridX, gridY, gridZ;
  uint64_t _stream;
  uint64_t _function;
  int launch_cooperative_grid;
  int launch_pdl;
  PyObject *launch_enter_hook = NULL;
  PyObject *launch_exit_hook = NULL;
  PyObject *kernel_metadata = NULL;
  PyObject *launch_metadata = NULL;
  PyObject *global_scratch_obj = NULL;
  PyObject *profile_scratch_obj = NULL;
  {newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])}
  if(!PyArg_ParseTuple(args, \"{parse_format}\", &gridX, &gridY, &gridZ,
                       &_stream, &_function, &launch_cooperative_grid, &launch_pdl, &global_scratch_obj, &profile_scratch_obj,
                       &kernel_metadata, &launch_metadata,
                       &launch_enter_hook, &launch_exit_hook{args_list})) {{
    return NULL;
  }}

  int num_warps, num_ctas, shared_memory;
  if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
    PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple");
    return NULL;
  }}

  // extract launch metadata
  if (launch_enter_hook != Py_None){{
    PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
    if (!ret)
      return NULL;
    Py_DECREF(ret);
  }}

  CUdeviceptr global_scratch = 0;
  if (global_scratch_obj != Py_None) {{
    DevicePtrInfo global_scratch_info = getPointer(global_scratch_obj, -1);
    if (!global_scratch_info.valid) {{
      return NULL;
    }}
    global_scratch = global_scratch_info.dev_ptr;
  }}

  CUdeviceptr profile_scratch = 0;
  if (profile_scratch_obj != Py_None) {{
    DevicePtrInfo profile_scratch_info = getPointer(profile_scratch_obj, -1);
    if (!profile_scratch_info.valid) {{
      return NULL;
    }}
    profile_scratch = profile_scratch_info.dev_ptr;
  }}

  // raise exception asap
  {newline.join(ptr_decls)}
  {newline.join(tma_decls)}
  {newline.join(float_storage_decls)}
  Py_BEGIN_ALLOW_THREADS;
  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch, profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
  Py_END_ALLOW_THREADS;
  if (PyErr_Occurred()) {{
    return NULL;
  }}

  if(launch_exit_hook != Py_None){{
    PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
    if (!ret)
      return NULL;
    Py_DECREF(ret);
  }}

  Py_RETURN_NONE;
}}

static PyMethodDef ModuleMethods[] = {{
  {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
  {{NULL, NULL, 0, NULL}} // sentinel
}};

static struct PyModuleDef ModuleDef = {{
  PyModuleDef_HEAD_INIT,
  \"__triton_launcher\",
  NULL, //documentation
  -1, //size
  ModuleMethods
}};

PyMODINIT_FUNC PyInit___triton_launcher(void) {{
  data_ptr_str = PyUnicode_InternFromString("data_ptr");
  if(data_ptr_str == NULL) {{
    return NULL;
  }}
  PyObject* driver_mod = PyImport_ImportModule("triton.backends.nvidia.driver");
  if (driver_mod == NULL) {{
    return NULL;
  }}
  py_tensor_map_type = PyObject_GetAttrString(driver_mod, "PyCUtensorMap");
  // The attribute holds its own reference; release the one to the module.
  Py_DECREF(driver_mod);
  if (py_tensor_map_type == NULL) {{
    return NULL;
  }}

  // PyModule_Create already registers ModuleDef's method table (ModuleMethods).
  PyObject *m = PyModule_Create(&ModuleDef);
  if(m == NULL) {{
    return NULL;
  }}
  return m;
}}
"""
    return src
# The TMA dtype enum values are slightly different on host vs device: device
# codes 8, 9 and 10 correspond to host codes 10, 8 and 9 respectively, and
# every other code in [0, 16) maps to itself.
TMA_DTYPE_DEVICE_TO_HOST = {code: code for code in range(16)}
TMA_DTYPE_DEVICE_TO_HOST.update({8: 10, 9: 8, 10: 9})
def make_tensordesc_arg(arg, metadata):
    """Expand one host-side tensor descriptor into flat launcher arguments.

    Args:
        arg: Host tensor descriptor exposing ``base``, ``shape``, ``strides`` and
            ``padding``. When ``metadata`` is given, ``base`` must also provide
            ``data_ptr()``.
        metadata: ``None`` when the kernel receives the descriptor decomposed.
            Otherwise, a mapping with ``swizzle``, ``elem_size``,
            ``elem_type``, ``block_size`` and ``fp4_padded`` that is used to
            build a CUtensorMap.

    Returns:
        If ``metadata`` is None: ``[base, *shape, *strides, padding == "nan",
        *shape, *strides]``.

        Otherwise: ``[tensor_map, *shape, *strides]``, where the innermost
        extent of ``shape`` is doubled when ``fp4_padded`` is set.
    """
    if metadata is not None:
        swizzle, elem_size, elem_type, block_size, fp4_padded = (
            metadata[key] for key in ("swizzle", "elem_size", "elem_type", "block_size", "fp4_padded"))

        shape = arg.shape
        strides = arg.strides
        assert strides[-1] == 1  # the innermost stride must be 1
        padding = 1 if arg.padding == "nan" else 0

        if fp4_padded:
            # Fresh list with the innermost extent doubled; arg.shape is left untouched.
            shape = [*shape[:-1], shape[-1] * 2]

        cu_tensor_map = triton.runtime.driver.active.utils.fill_tma_descriptor(
            arg.base.data_ptr(),
            swizzle,
            elem_size,
            TMA_DTYPE_DEVICE_TO_HOST[elem_type],
            block_size,
            shape,
            strides,
            padding,
        )
        return [cu_tensor_map, *shape, *strides]

    # Currently the host side tensor descriptors get decomposed in
    # the frontend to tensor desc, shape, and strides. We have no
    # way to use these shape and strides when processing tensor
    # descriptors which is why we provide our own decomposition
    # above. Sadly this means we have to pass the shape and strides
    # twice.
    return [arg.base, *arg.shape, *arg.strides, arg.padding == "nan", *arg.shape, *arg.strides]
def wrap_handle_tensordesc(launcher, signature, tensordesc_meta):
    """Return a launcher that accepts host tensor descriptors as kernel arguments.

    If ``signature`` has no top-level ``tensordesc`` entry, ``launcher`` is
    returned unchanged. Otherwise the returned callable does two things:

    - Forwards the first ``_BASE_ARGS_FORMAT_LEN`` (fixed launch) arguments
      untouched.
    - Replaces each kernel argument at a tensor-descriptor position with the
      flat list produced by ``make_tensordesc_arg``.

    Args:
        launcher: The native ``launch`` function to wrap.
        signature: Mapping of kernel argument index -> type (str or tuple).
        tensordesc_meta: Per-descriptor metadata in signature order, or a falsy
            value meaning every descriptor is passed decomposed.
    """
    desc_positions = {
        pos
        for pos, sig in enumerate(signature.values())
        if isinstance(sig, str) and sig.startswith("tensordesc")
    }
    if not desc_positions:
        return launcher

    assert not tensordesc_meta or len(tensordesc_meta) == len(desc_positions)
    desc_meta = tensordesc_meta if tensordesc_meta else [None] * len(desc_positions)

    def inner(*args):
        expanded = list(args[:_BASE_ARGS_FORMAT_LEN])
        seen_descs = 0
        for pos, value in enumerate(args[_BASE_ARGS_FORMAT_LEN:]):
            if pos not in desc_positions:
                expanded.append(value)
                continue
            expanded.extend(make_tensordesc_arg(value, desc_meta[seen_descs]))
            seen_descs += 1
        return launcher(*expanded)

    return inner
class CudaLauncher(object):
    """Callable that launches one compiled kernel through a generated C launcher.

    On construction, the launcher source for the kernel's signature is
    generated with ``make_launcher`` and compiled into a ``__triton_launcher``
    extension. The result is wrapped by ``wrap_handle_tensordesc`` so that host
    tensor descriptors are expanded before the native call.
    """

    def __init__(self, src, metadata):
        def to_key(key):
            # String keys name a top-level argument: turn them into a 1-tuple
            # holding its positional index. Other keys are kept as-is.
            if isinstance(key, str):
                return (src.fn.arg_names.index(key), )
            return key

        constants = {to_key(key): value for key, value in getattr(src, "constants", dict()).items()}
        signature = dict(src.signature.items())
        tensordesc_meta = getattr(metadata, "tensordesc_meta", None)
        mod = compile_module_from_src(
            src=make_launcher(constants, signature, tensordesc_meta),
            name="__triton_launcher",
            library_dirs=library_dirs(),
            include_dirs=include_dirs,
            libraries=libraries,
        )
        self.num_ctas = getattr(metadata, "num_ctas", 1)
        self.launch = wrap_handle_tensordesc(mod.launch, signature, tensordesc_meta)
        self.global_scratch_size = metadata.global_scratch_size
        self.global_scratch_align = metadata.global_scratch_align
        self.profile_scratch_size = metadata.profile_scratch_size
        self.profile_scratch_align = metadata.profile_scratch_align
        self.launch_cooperative_grid = metadata.launch_cooperative_grid
        self.launch_pdl = metadata.launch_pdl

    def __call__(self, gridX, gridY, gridZ, stream, function, *args):
        def allocate_scratch(size, align, allocator):
            # Nothing to allocate for a zero-sized scratch request.
            if not size > 0:
                return None
            # One `size`-byte slab per CTA of every program in the grid.
            total_bytes = gridX * gridY * gridZ * self.num_ctas * size
            return allocator.get()(total_bytes, align, stream)

        global_scratch = allocate_scratch(
            self.global_scratch_size, self.global_scratch_align, _allocation._allocator)
        profile_scratch = allocate_scratch(
            self.profile_scratch_size, self.profile_scratch_align, _allocation._profile_allocator)
        self.launch(
            gridX, gridY, gridZ, stream, function,
            self.launch_cooperative_grid, self.launch_pdl,
            global_scratch, profile_scratch,
            *args,
        )
class CudaDriver(GPUDriver):
    """Triton runtime driver for NVIDIA GPUs.

    Device queries and the benchmarking cache buffer go through PyTorch's CUDA
    runtime. Kernel launches go through ``CudaLauncher``.
    """

    def __init__(self):
        self.utils = CudaUtils()  # TODO: make static
        self.launcher_cls = CudaLauncher
        super().__init__()

    def get_current_target(self):
        """Describe the current device as ``GPUTarget("cuda", major*10 + minor, 32)``."""
        cap = self.get_device_capability(self.get_current_device())
        warp_size = 32
        return GPUTarget("cuda", 10 * cap[0] + cap[1], warp_size)

    def get_active_torch_device(self):
        """Return the ``torch.device`` for the current CUDA device index."""
        import torch
        current = self.get_current_device()
        return torch.device("cuda", current)

    def get_device_interface(self):
        """Return the ``torch.cuda`` module."""
        import torch
        cuda_module = torch.cuda
        return cuda_module

    @staticmethod
    def is_active():
        """True when torch imports and reports an available CUDA (non-HIP) runtime."""
        try:
            import torch
            cuda_ok = torch.cuda.is_available()
            return cuda_ok and torch.version.hip is None
        except ImportError:
            return False

    def map_python_to_cpp_type(self, ty: str) -> str:
        """Delegate to ``ty_to_cpp``."""
        return ty_to_cpp(ty)

    def get_benchmarker(self):
        """Return ``triton.testing.do_bench``."""
        from triton.testing import do_bench
        benchmarker = do_bench
        return benchmarker

    def get_empty_cache_for_benchmark(self):
        """Allocate the 256 MB device buffer that ``clear_cache`` zeroes.

        Clearing it before a timed kernel call is meant to evict the kernel's
        input data from the L2 cache.
        """
        import torch
        cache_size_bytes = 256 * 1024 * 1024
        num_int32 = cache_size_bytes // 4
        return torch.empty(num_int32, dtype=torch.int, device='cuda')

    def clear_cache(self, cache):
        """Zero ``cache`` in place."""
        cache.zero_()
|