| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978 |
- import functools
- import os
- import platform
- import subprocess
- import re
- import triton
- from pathlib import Path
- from triton import knobs
- from triton.backends.compiler import GPUTarget
- from triton.backends.driver import GPUDriver
- from triton.runtime import _allocation
- from triton.runtime.build import compile_module_from_src
# Symlink-resolved directory containing this module; HIPUtils reads driver.c from here.
dirname = os.path.split(os.path.realpath(__file__))[0]
# Header search directories handed to compile_module_from_src for every extension built here.
include_dirs = [os.path.join(dirname, "include")]
# Filled in by HIPUtils.__init__ with the PyTDMDescriptor type of the compiled hip_utils module.
PyTDMDescriptor = None
- def _is_windows():
- return platform.system() == 'Windows'
- def _get_rocm_sdk_root():
- """Get ROCm SDK root path using rocm-sdk command or environment variables."""
- # Try rocm-sdk path --root first (for Windows ROCm SDK)
- try:
- result = subprocess.check_output(["rocm-sdk", "path", "--root"], stderr=subprocess.DEVNULL)
- root = result.decode().strip()
- if root and os.path.isdir(root):
- return root
- except (subprocess.CalledProcessError, FileNotFoundError):
- pass
- # Fall back to environment variables
- for env_var in ["ROCM_HOME", "HIP_PATH", "ROCM_PATH"]:
- path = os.environ.get(env_var, "")
- if path and os.path.isdir(path):
- return path
- return None
- def _get_hip_library_from_rocm_sdk():
- """Get the amdhip64 library path using rocm_sdk.find_libraries."""
- try:
- import rocm_sdk
- paths = rocm_sdk.find_libraries("amdhip64")
- if paths:
- return str(paths[0])
- except (ImportError, ModuleNotFoundError, FileNotFoundError):
- pass
- return None
# Make the ROCm SDK's headers (when an SDK root is found) visible to every extension compiled
# by this backend.
_rocm_root = _get_rocm_sdk_root()
if _rocm_root:
    # Only add the directory when the SDK root actually contains an include/ subdirectory.
    if os.path.isdir(os.path.join(_rocm_root, "include")):
        include_dirs.append(os.path.join(_rocm_root, "include"))
- def _find_already_mmapped_dylib_on_linux(lib_name):
- if platform.system() != 'Linux':
- return None
- # Use dl_iterate_phdr to walk through the list of shared libraries at runtime.
- # See https://www.man7.org/linux/man-pages/man3/dl_iterate_phdr.3.html for details.
- import ctypes
- from ctypes import c_char, c_int, c_size_t, c_void_p, c_char_p, POINTER
- class DlPhdrInfo(ctypes.Structure):
- _fields_ = [
- ('dlpi_addr', c_void_p),
- ('dlpi_name', c_char_p),
- # We don't care about the remaining fields.
- ]
- # callback_t must use POINTER(c_char) to avoid copying.
- callback_t = ctypes.CFUNCTYPE(c_int, POINTER(DlPhdrInfo), POINTER(c_size_t), POINTER(c_char))
- # Load libc and get the dl_iterate_phdr symbol.
- try:
- dl_iterate_phdr = ctypes.CDLL('libc.so.6').dl_iterate_phdr
- except Exception:
- return None
- # argtypes must use c_char_p to accept create_string_buffer.
- dl_iterate_phdr.argtypes = [callback_t, c_char_p]
- dl_iterate_phdr.restype = c_int
- max_path_length = 4096
- path = ctypes.create_string_buffer(max_path_length + 1)
- # Define callback to get the loaded dylib path.
- def callback(info, size, data):
- dlpi_name = info.contents.dlpi_name
- p = Path(os.fsdecode(dlpi_name))
- if lib_name in p.name:
- # Found the dylib; get its path.
- ctypes.memmove(data, dlpi_name, min(max_path_length, len(dlpi_name)))
- return 1
- return 0
- if dl_iterate_phdr(callback_t(callback), path):
- return os.fsdecode(ctypes.string_at(path))
- return None
@functools.lru_cache()
def _get_path_to_hip_runtime_dylib():
    """Locate the HIP runtime dynamic library (``libamdhip64.so`` / ``amdhip64.dll``).

    Candidates are tried in this order; the first that exists is returned:
      1. ``knobs.amd.libhip_path`` (TRITON_LIBHIP_PATH). If set, it must exist and end with the
         library name.
      2. ``rocm_sdk.find_libraries``.
      3. A copy already mapped into this process (Linux only).
      4. ``lib/`` next to this file.
      5. ``torch/lib`` in each site-packages directory (user site first when enabled).
      6. Each directory of ``PATH`` (Windows) or ``LD_LIBRARY_PATH`` (Linux).
      7. ``$HIP_PATH/bin`` (Windows) or ``$HIP_PATH/lib`` (Linux).
      8. The root printed by ``rocm-sdk path --root`` (Windows) or ``hipconfig --path`` (Linux).
      9. ``$ROCM_PATH`` or, if unset, ``$ROCM_HOME``.
     10. ``ldconfig -p`` entries (Linux only).
     11. ``/opt/rocm/lib`` (Linux only).

    The result is cached by ``lru_cache``.

    Returns:
        str: path to the library.

    Raises:
        RuntimeError: if the explicit override or an already-mapped library path is invalid, or
            if no candidate exists. In the last case the message lists every path tried.
    """
    windows = _is_windows()
    lib_name = "amdhip64.dll" if windows else "libamdhip64.so"
    # Under a HIP/ROCm install root, DLLs live in bin on Windows; .so files live in lib on Linux.
    lib_subdir = "bin" if windows else "lib"

    # If we are told explicitly what HIP runtime dynamic library to use, obey that.
    if env_libhip_path := knobs.amd.libhip_path:
        if env_libhip_path.endswith(lib_name) and os.path.exists(env_libhip_path):
            return env_libhip_path
        raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}")

    # Try rocm_sdk.find_libraries first - this is the preferred method.
    rocm_sdk_path = _get_hip_library_from_rocm_sdk()
    if rocm_sdk_path:
        return rocm_sdk_path

    # If the shared object is already mmapped to address space, use it (Linux only).
    if not windows:
        mmapped_path = _find_already_mmapped_dylib_on_linux(lib_name)
        if mmapped_path:
            if os.path.exists(mmapped_path):
                return mmapped_path
            raise RuntimeError(f"memory mapped '{mmapped_path}' in process does not point to a valid {lib_name}")

    # Every candidate that did not exist, in search order; reported if the search fails.
    paths = []

    def first_existing(candidates):
        """Return the first existing path in `candidates`, recording each miss in `paths`."""
        for candidate in candidates:
            if os.path.exists(candidate):
                return candidate
            paths.append(candidate)
        return None

    # Check the library bundled with this backend.
    if found := first_existing([os.path.join(os.path.dirname(__file__), "lib", lib_name)]):
        return found

    # Search the HIP runtime dynamic library packaged with PyTorch. Triton very likely runs
    # together with PyTorch, and using the same library avoids a version mismatch.
    import site
    site_packages = site.getsitepackages()
    user_site = site.getusersitepackages()
    if site.ENABLE_USER_SITE:  # ENABLE_USER_SITE is initialized in getusersitepackages()
        site_packages = [user_site] + site_packages
    if found := first_existing([os.path.join(p, "torch", "lib", lib_name) for p in site_packages]):
        return found

    # A library provided by the developer through LD_LIBRARY_PATH (Linux) or PATH (Windows).
    env_path = os.getenv("PATH" if windows else "LD_LIBRARY_PATH", "")
    if env_path:
        path_sep = ";" if windows else ":"
        if found := first_existing([os.path.join(d, lib_name) for d in env_path.split(path_sep)]):
            return found

    # HIP_PATH should point to the HIP SDK root if set.
    env_hip_path = os.getenv("HIP_PATH")
    if env_hip_path:
        if found := first_existing([os.path.join(env_hip_path, lib_subdir, lib_name)]):
            return found

    # Ask the SDK tooling for its root: rocm-sdk (Windows ROCm SDK) or hipconfig (Linux).
    try:
        if windows:
            rocm_root = subprocess.check_output(["rocm-sdk", "path", "--root"],
                                                stderr=subprocess.DEVNULL).decode().strip()
        else:
            rocm_root = subprocess.check_output(["hipconfig", "--path"]).decode().strip()
    except (subprocess.CalledProcessError, OSError):
        # The tool is missing, not executable, or failed; keep searching.
        rocm_root = ""
    if rocm_root:
        if found := first_existing([os.path.join(rocm_root, lib_subdir, lib_name)]):
            return found

    # ROCm lib dir based on env var.
    env_rocm_path = os.getenv("ROCM_PATH") or os.getenv("ROCM_HOME")
    if env_rocm_path:
        if found := first_existing([os.path.join(env_rocm_path, lib_subdir, lib_name)]):
            return found

    if not windows:
        # Search the dynamic loader's resolution cache.
        try:
            libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore")
        except (subprocess.CalledProcessError, OSError):
            libs = ""
        # Each line looks like the following:
        #   libamdhip64.so.6 (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so.6
        #   libamdhip64.so (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so
        locs = [line.split()[-1] for line in libs.splitlines() if line.strip().endswith(lib_name)]
        if found := first_existing(locs):
            return found

        # As a last resort on Linux, guess that it sits in a common installation path.
        if found := first_existing([os.path.join('/opt/rocm/lib/', lib_name)]):
            return found

    raise RuntimeError(f"cannot locate {lib_name} after attempted paths {paths}")
class HIPUtils(object):
    """Process-wide singleton exposing helpers from the compiled ``hip_utils`` extension.

    The extension is built from ``driver.c`` in this module's directory, with the resolved HIP
    runtime library path substituted in. It provides ``load_binary``, ``get_device_properties``
    and ``create_tdm_descriptor``. Its ``PyTDMDescriptor`` type is published as this module's
    ``PyTDMDescriptor`` global.
    """

    def __new__(cls):
        if not hasattr(cls, "instance"):
            cls.instance = super().__new__(cls)
        return cls.instance

    def __init__(self):
        # __new__ always returns the shared instance, but Python still calls __init__ on every
        # HIPUtils() call. Without this guard, each call would re-read driver.c, reload the
        # extension, and rebind the global PyTDMDescriptor.
        if getattr(self, "_initialized", False):
            return
        libhip_path = _get_path_to_hip_runtime_dylib()
        # Escape for embedding inside a C string literal: backslashes (Windows paths) and
        # double quotes.
        libhip_path_escaped = libhip_path.replace("\\", "\\\\").replace('"', '\\"')
        src = Path(os.path.join(dirname, "driver.c")).read_text()
        # Just do a simple search and replace here instead of templates or format strings.
        # This way we don't need to escape-quote C code curly brackets and we can replace
        # exactly once.
        src = src.replace('/*py_libhip_search_path*/', libhip_path_escaped, 1)
        mod = compile_module_from_src(src=src, name="hip_utils", include_dirs=include_dirs)
        self.load_binary = mod.load_binary
        self.get_device_properties = mod.get_device_properties
        self.create_tdm_descriptor = mod.create_tdm_descriptor
        global PyTDMDescriptor
        PyTDMDescriptor = mod.PyTDMDescriptor
        self._initialized = True
# -------------------- Launcher ----------------------------
def ty_to_cpp(ty):
    """Map a Triton signature type string to the C type used by generated launcher code.

    Pointer types (``"*..."``) map to ``hipDeviceptr_t`` and ``"tensordesc"`` to
    ``TDMDescriptor``. Integer types map to fixed-width C integers. Every float kind maps to
    ``double``, the type it is parsed into from Python.

    Raises:
        KeyError: for any other type string.
    """
    if ty.startswith('*'):
        return "hipDeviceptr_t"
    if ty == "tensordesc":
        return "TDMDescriptor"
    cpp_types = {name: "double" for name in ("fp16", "bf16", "fp32", "f32", "fp64")}
    cpp_types.update(
        i1="int8_t",
        i8="int8_t",
        i16="int16_t",
        i32="int32_t",
        i64="int64_t",
        u1="uint8_t",
        u8="uint8_t",
        u16="uint16_t",
        u32="uint32_t",
        u64="uint64_t",
    )
    return cpp_types[ty]
- FLOAT_STORAGE_TYPE = {
- "fp16": "uint16_t",
- "bf16": "uint16_t",
- "fp32": "uint32_t",
- "f32": "uint32_t",
- "fp64": "uint64_t",
- }
- FLOAT_PACK_FUNCTION = {
- "fp16": "pack_fp16",
- "bf16": "pack_bf16",
- "fp32": "pack_fp32",
- "f32": "pack_fp32",
- "fp64": "pack_fp64",
- }
- _BASE_ARGS_FORMAT = "piiiKKOOOOO"
def make_launcher(constants, signature, warp_size, tensordesc_meta):
    """Generate the C source of the ``__triton_launcher`` extension for one kernel signature.

    The generated module exposes ``launch(...)``. It:
      - parses the fixed arguments described by ``_BASE_ARGS_FORMAT``, followed by the kernel
        arguments;
      - converts those arguments to C values;
      - calls ``hipModuleLaunchKernel``, ``hipModuleLaunchCooperativeKernel`` or
        ``hipDrvLaunchKernelEx``.
    These functions are resolved from the HIP runtime library when the module is initialized.

    Args:
        constants: Not used when generating the source; accepted for interface compatibility.
        signature: Mapping whose values are the kernel argument types in order (``"*fp16"``,
            ``"i32"``, ``"constexpr"``, ``"tensordesc<dtype[shape]>"``, or tuples of these).
        warp_size: Threads per warp; the launched block size is ``warp_size * num_warps``.
        tensordesc_meta: Per-tensor-descriptor metadata in argument order, or None. A descriptor
            without metadata is passed decomposed into base pointer, shape, strides and a
            padding flag.

    Returns:
        str: C source text of the launcher extension.
    """

    def _expand_signature(signature):
        """Replace each ``tensordesc<dtype[shape]>`` entry with the flat argument types passed for it.

        Without metadata the descriptor becomes a base pointer, ``2 * ndim`` i64 values and an
        i1 padding flag. With metadata it becomes a single ``tensordesc`` (TDMDescriptor) value.
        Either form is followed by ``ndim`` i32 and ``ndim`` i64 values, mirroring the argument
        list built by ``make_tensordesc_arg``.
        """
        output = []
        tensordesc_idx = 0
        for sig in signature:
            if isinstance(sig, str) and sig.startswith("tensordesc"):
                meta = tensordesc_meta[tensordesc_idx] if tensordesc_meta else None
                tensordesc_idx += 1

                match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", sig)
                dtype = match.group(1)
                shape = match.group(2)
                ndim = shape.count(",") + 1

                # If there is no descriptor's metadata, the descriptor has been decomposed to base pointer, shape and strides
                if meta is None:
                    output.append("*" + dtype)
                    for _ in range(2 * ndim):
                        output.append("i64")
                    output.append("i1")
                else:
                    output.append("tensordesc")

                for _ in range(ndim):
                    output.append("i32")
                for _ in range(ndim):
                    output.append("i64")
            else:
                output.append(sig)
        return output

    def _serialize_signature(sig):
        # Flatten (possibly nested) tuple types into one comma-separated string.
        if isinstance(sig, tuple):
            return ','.join(map(_serialize_signature, sig))
        return sig

    def _extracted_type(ty):
        # C type of the local variable PyArg_ParseTuple writes this argument into.
        if isinstance(ty, tuple):
            val = ','.join(map(_extracted_type, ty))
            return f"[{val}]"
        if ty.startswith("*") or ty.startswith("tensordesc"):
            return "PyObject*"
        if ty == "constexpr":
            return "PyObject*"
        return ty_to_cpp(ty)

    def format_of(ty):
        # PyArg_ParseTuple format unit for ty; tuple types become a parenthesized group.
        if isinstance(ty, tuple):
            val = ''.join(map(format_of, ty))
            return f"({val})"
        if ty.startswith("*") or ty.startswith("tensordesc"):
            return "O"
        if ty == "constexpr":
            return "O"
        return {
            "double": "d",
            "long": "l",
            "int8_t": "b",
            "int16_t": "h",
            "int32_t": "i",
            "int64_t": "L",
            "uint8_t": "B",
            "uint16_t": "H",
            "uint32_t": "I",
            "uint64_t": "K",
        }[ty_to_cpp(ty)]

    signature = {idx: s for idx, s in enumerate(_expand_signature(signature.values()))}
    args_format = ''.join([format_of(ty) for ty in signature.values()])
    parse_format = _BASE_ARGS_FORMAT + args_format
    # Flatten tuple types so each leaf argument gets its own C variable; empty tuples disappear.
    signature = ','.join(map(_serialize_signature, signature.values()))
    signature = list(filter(bool, signature.split(',')))
    signature = {i: s for i, s in enumerate(signature)}
    args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''

    # Parameters of _launch(): constexprs are dropped and floats are carried as raw bits.
    arg_decl_list = []
    for i, ty in signature.items():
        if ty == "constexpr":
            continue
        if ty in FLOAT_STORAGE_TYPE:
            arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
        else:
            arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
    arg_decls = ', '.join(arg_decl_list)

    # Matching argument expressions passed from launch() to _launch().
    internal_args_list = []
    for i, ty in signature.items():
        if ty.startswith("*"):
            internal_args_list.append(f"ptr_info{i}.dev_ptr")
        elif ty.startswith("tensordesc"):
            internal_args_list.append(f"*desc{i}")
        elif ty in FLOAT_STORAGE_TYPE:
            internal_args_list.append(f"_arg{i}_storage")
        elif ty != "constexpr":
            internal_args_list.append(f"_arg{i}")

    newline = '\n  '
    ptr_decls = [
        f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;"
        for i, ty in signature.items()
        if ty.startswith("*")
    ]
    # getTDMDescriptor returns NULL (with a TypeError set) for a wrong argument type; bail out
    # before _launch dereferences the pointer.
    tensor_desc_decls = [
        f"TDMDescriptor* desc{i} = getTDMDescriptor(_arg{i}, {i}); if (!desc{i}) return NULL;"
        for i, ty in signature.items()
        if ty.startswith("tensordesc")
    ]
    float_storage_decls = [
        f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
        for i, ty in signature.items()
        if ty in FLOAT_STORAGE_TYPE
    ]

    libhip_path = _get_path_to_hip_runtime_dylib()
    # The path is embedded in a C string literal: escape backslashes and double quotes.
    libhip_path = libhip_path.replace("\\", "\\\\").replace('"', '\\"')

    # Kernel parameter array: addresses of the non-constexpr arguments, then the two scratch pointers.
    params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
    params.append("&global_scratch")
    params.append("&profile_scratch")

    # Platform-specific includes and dlopen/dlsym macros.
    if _is_windows():
        platform_includes = """
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#include <Python.h>
#include <windows.h>
#include <stdbool.h>

// Windows compatibility layer for dlopen/dlsym/dlclose
static char _dlerror_buf[512];
static char _dlerror_msg[512];
static inline void *dlopen(const char *filename, int flags) {
  (void)flags;
  HMODULE h = LoadLibraryA(filename);
  if (!h) {
    snprintf(_dlerror_buf, sizeof(_dlerror_buf), "LoadLibrary failed with error %lu", GetLastError());
  }
  return (void *)h;
}
static inline void *dlsym(void *handle, const char *symbol) {
  void *p = (void *)GetProcAddress((HMODULE)handle, symbol);
  if (!p) {
    snprintf(_dlerror_buf, sizeof(_dlerror_buf), "GetProcAddress failed for %s with error %lu", symbol, GetLastError());
  }
  return p;
}
static inline int dlclose(void *handle) { return FreeLibrary((HMODULE)handle) ? 0 : -1; }
// Like POSIX dlerror(): report the most recent error once, then clear it.
static inline const char *dlerror(void) {
  if (!_dlerror_buf[0]) return NULL;
  memcpy(_dlerror_msg, _dlerror_buf, sizeof(_dlerror_msg));
  _dlerror_buf[0] = '\\0';
  return _dlerror_msg;
}
#define RTLD_LAZY 0
#define RTLD_LOCAL 0
#define RTLD_NOLOAD 0
"""
    else:
        platform_includes = """
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#include <Python.h>
#include <dlfcn.h>
#include <stdbool.h>
"""

    src = f"""{platform_includes}
typedef struct {{
  uint32_t group0_0;
  uint32_t group0_1;
  uint32_t group0_2;
  uint32_t group0_3;
  uint32_t group1_0;
  uint32_t group1_1;
  uint32_t group1_2;
  uint32_t group1_3;
  uint32_t group1_4;
  uint32_t group1_5;
  uint32_t group1_6;
  uint32_t group1_7;
}} TDMDescriptor;

typedef struct {{
  PyObject_HEAD;
  TDMDescriptor desc;
}} PyTDMDescriptorObject;

// Path of the HIP runtime library, resolved by the Python driver and embedded at
// source-generation time.
static const char *hipLibSearchPaths[] = {{"{libhip_path}"}};

// The list of HIP dynamic library symbols and their signature we are interested
// in this file.
#define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN) \\
  FOR_EACH_ERR_FN(hipGetLastError, true) \\
  FOR_EACH_STR_FN(hipGetErrorString, true, hipError_t hipError) \\
  FOR_EACH_ERR_FN(hipDrvLaunchKernelEx, false, \\
                  const HIP_LAUNCH_CONFIG *config, \\
                  hipFunction_t f, \\
                  void **kernelParams, \\
                  void **extra) \\
  FOR_EACH_ERR_FN(hipModuleLaunchKernel, true, hipFunction_t f, \\
                  unsigned int gridDimX, unsigned int gridDimY, \\
                  unsigned int gridDimZ, unsigned int blockDimX, \\
                  unsigned int blockDimY, unsigned int blockDimZ, \\
                  unsigned int sharedMemBytes, hipStream_t stream, \\
                  void **kernelParams, void **extra) \\
  FOR_EACH_ERR_FN(hipModuleLaunchCooperativeKernel, true, hipFunction_t f, \\
                  unsigned int gridDimX, unsigned int gridDimY, \\
                  unsigned int gridDimZ, unsigned int blockDimX, \\
                  unsigned int blockDimY, unsigned int blockDimZ, \\
                  unsigned int sharedMemBytes, hipStream_t stream, \\
                  void **kernelParams, void **extra) \\
  FOR_EACH_ERR_FN(hipPointerGetAttribute, true, void *data, \\
                  hipPointer_attribute attribute, hipDeviceptr_t ptr)

// The HIP symbol table for holding resolved dynamic library symbols.
struct HIPSymbolTable {{
#define DEFINE_EACH_ERR_FIELD(hipSymbolName, required, ...) \\
  hipError_t (*hipSymbolName)(__VA_ARGS__);
#define DEFINE_EACH_STR_FIELD(hipSymbolName, required, ...) \\
  const char *(*hipSymbolName)(__VA_ARGS__);

  HIP_SYMBOL_LIST(DEFINE_EACH_ERR_FIELD, DEFINE_EACH_STR_FIELD)
}};

static struct HIPSymbolTable hipSymbolTable;

bool initSymbolTable() {{
  void *lib = NULL;

  // Go through the list of search paths to open the first HIP driver library.
  int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]);
  for (int i = 0; i < n; ++i) {{
    void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
    if (handle) {{
      lib = handle;
      break;
    }}
  }}
  if (!lib) {{
    PyErr_SetString(PyExc_RuntimeError, "cannot open HIP runtime library");
    return false;
  }}

  typedef hipError_t (*hipGetProcAddress_fn)(
      const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
      hipDriverProcAddressQueryResult *symbolStatus);
  hipGetProcAddress_fn hipGetProcAddress;
  dlerror(); // Clear existing errors
  const char *error = NULL;
  *(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
  error = dlerror();
  if (error) {{
    PyErr_SetString(PyExc_RuntimeError,
                    "cannot query 'hipGetProcAddress' from HIP runtime library");
    dlclose(lib);
    return false;
  }}

  // Resolve all symbols we are interested in.
  int hipVersion = HIP_VERSION;
  uint64_t hipFlags = 0;
  hipDriverProcAddressQueryResult symbolStatus;
  hipError_t status = hipSuccess;
#define QUERY_EACH_FN(hipSymbolName, required, ...) \\
  status = hipGetProcAddress(#hipSymbolName, \\
                             (void **)&hipSymbolTable.hipSymbolName, \\
                             hipVersion, hipFlags, &symbolStatus); \\
  if (required && status != hipSuccess) {{ \\
    PyErr_SetString(PyExc_RuntimeError, \\
                    "cannot get address for '" #hipSymbolName \\
                    "' from libamdhip64.so"); \\
    dlclose(lib); \\
    return false; \\
  }}

  HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN)

  return true;
}}

static inline void gpuAssert(hipError_t code, const char *file, int line)
{{
   if (code != HIP_SUCCESS)
   {{
      const char* prefix = "Triton Error [HIP]: ";
      const char* str = hipSymbolTable.hipGetErrorString(code);
      char err[1024] = {{0}};
      snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
      PyErr_SetString(PyExc_RuntimeError, err);
   }}
}}

#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}

static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
  if (gridX * gridY * gridZ == 0)
    return;
  hipDeviceptr_t global_scratch = 0;
  void *params[] = {{ {', '.join(params)} }};
  if(num_ctas > 1) {{
    if (!hipSymbolTable.hipDrvLaunchKernelEx) {{
      PyErr_SetString(PyExc_RuntimeError, "missing hipDrvLaunchKernelEx symbol; please update HIP runtime");
      return;
    }}

    hipLaunchAttribute attributes[2];
    // Attribute0: Cluster dimensions
    attributes[0].id = 4;
    int *cluster_dims = (int*)attributes[0].val.pad;
    cluster_dims[0] = num_ctas;
    cluster_dims[1] = 1;
    cluster_dims[2] = 1;
    // Attribute1: Cooperative launch
    attributes[1].id = hipLaunchAttributeCooperative;
    attributes[1].val.cooperative = launch_cooperative_grid;

    HIP_LAUNCH_CONFIG config = {{
      gridX * num_ctas, gridY, gridZ, // Grid size
      {warp_size} * num_warps, 1, 1, // Block size
      shared_memory, stream,
      attributes, 2 // Number of attributes
    }};
    HIP_CHECK(hipSymbolTable.hipDrvLaunchKernelEx(&config, function, params, 0));
    return;
  }}
  else if (launch_cooperative_grid) {{
    HIP_CHECK(hipSymbolTable.hipModuleLaunchCooperativeKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
    return;
  }}
  else {{
    HIP_CHECK(hipSymbolTable.hipModuleLaunchKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
  }}
}}

typedef struct _DevicePtrInfo {{
  hipDeviceptr_t dev_ptr;
  bool valid;
}} DevicePtrInfo;

static PyObject* data_ptr_str = NULL;
static PyObject* py_tdm_descriptor_type = NULL;

static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
  DevicePtrInfo ptr_info;
  hipError_t status = hipSuccess;
  ptr_info.dev_ptr = 0;
  ptr_info.valid = true;
  if (PyLong_Check(obj)) {{
    ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
    return ptr_info;
  }}
  if (obj == Py_None) {{
    // valid nullptr
    return ptr_info;
  }}
  PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
  if (!ret) {{
    PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
    ptr_info.valid = false;
    goto cleanup;
  }}
  if (!PyLong_Check(ret)) {{
    PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
    ptr_info.valid = false;
    goto cleanup;
  }}
  ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
  if (!ptr_info.dev_ptr)
    goto cleanup;
  uint64_t dev_ptr = 0;
  status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
  if (status == hipErrorInvalidValue) {{
      PyErr_Format(PyExc_ValueError,
                   "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
      ptr_info.valid = false;
      // Clear and ignore HIP error
      (void)hipSymbolTable.hipGetLastError();
  }}
  ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
cleanup:
  // ret is NULL when the data_ptr() call itself failed.
  Py_XDECREF(ret);
  return ptr_info;
}}

static inline TDMDescriptor* getTDMDescriptor(PyObject* obj, int idx) {{
  if (Py_TYPE(obj) != (PyTypeObject*)py_tdm_descriptor_type) {{
    PyErr_Format(PyExc_TypeError, "object must be of type PyTDMDescriptor, got %s", Py_TYPE(obj)->tp_name);
    return NULL;
  }}

  TDMDescriptor* desc = &((PyTDMDescriptorObject*)obj)->desc;
  return desc;
}}

static uint16_t pack_fp16(double f) {{
    uint16_t result;
    // from https://github.com/python/pythoncapi-compat/blob/5e317108f872c904eb726cb8d560dcadbdf88a72/pythoncapi_compat.h#L482-L492
#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
    _PyFloat_Pack2(f, (unsigned char*)&result, 1);
#else
    PyFloat_Pack2(f, (char*)&result, 1);
#endif
    return result;
}}

static uint16_t pack_bf16(double f) {{
    float f32 = (float)f;
    uint32_t u32 = *(uint32_t*)&f32;
    return (uint16_t)(u32 >> 16);
}}

static uint32_t pack_fp32(double f) {{
    float f32 = (float)f;
    return *(uint32_t*)&f32;
}}

static uint64_t pack_fp64(double f) {{
    return *(uint64_t*)&f;
}}

static PyObject* launch(PyObject* self, PyObject* args) {{
  int gridX, gridY, gridZ;
  uint64_t _stream;
  uint64_t _function;
  int launch_cooperative_grid;
  PyObject *profile_scratch_obj = NULL;
  PyObject *launch_enter_hook = NULL;
  PyObject *launch_exit_hook = NULL;
  PyObject *kernel_metadata = NULL;
  PyObject *launch_metadata = NULL;
  {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
  if(!PyArg_ParseTuple(args, \"{parse_format}\", &launch_cooperative_grid,
                       &gridX, &gridY, &gridZ, &_stream, &_function, &profile_scratch_obj,
                       &kernel_metadata, &launch_metadata,
                       &launch_enter_hook, &launch_exit_hook {args_list})) {{
    return NULL;
  }}

  // extract kernel metadata
  int num_warps, num_ctas, shared_memory;
  if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
    return NULL;
  }}
  // extract launch metadata
  if (launch_enter_hook != Py_None){{
    PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
    if (!ret)
      return NULL;
    Py_DECREF(ret);
  }}

  hipDeviceptr_t profile_scratch = 0;
  if (profile_scratch_obj != Py_None) {{
    DevicePtrInfo profile_scratch_info = getPointer(profile_scratch_obj, -1);
    if (!profile_scratch_info.valid) {{
      return NULL;
    }}
    profile_scratch = profile_scratch_info.dev_ptr;
  }}

  // raise exception asap
  {newline.join(tensor_desc_decls)}
  {newline.join(ptr_decls)}
  {newline.join(float_storage_decls)}
  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});

  if(launch_exit_hook != Py_None){{
    PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
    if (!ret)
      return NULL;
    Py_DECREF(ret);
  }}

  if(PyErr_Occurred()) {{
    return NULL;
  }}
  Py_RETURN_NONE;
}}

static PyMethodDef ModuleMethods[] = {{
  {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
  {{NULL, NULL, 0, NULL}} // sentinel
}};

static struct PyModuleDef ModuleDef = {{
  PyModuleDef_HEAD_INIT,
  \"__triton_launcher\",
  NULL, //documentation
  -1, //size
  ModuleMethods
}};

PyMODINIT_FUNC PyInit___triton_launcher(void) {{
  if (!initSymbolTable()) {{
    return NULL;
  }}
  PyObject *m = PyModule_Create(&ModuleDef);
  if(m == NULL) {{
    return NULL;
  }}
  data_ptr_str = PyUnicode_InternFromString("data_ptr");
  if(data_ptr_str == NULL) {{
    Py_DECREF(m);
    return NULL;
  }}
  PyObject* driver_mod = PyImport_ImportModule("triton.backends.amd.driver");
  if (driver_mod == NULL) {{
    Py_DECREF(m);
    return NULL;
  }}
  py_tdm_descriptor_type = PyObject_GetAttrString(driver_mod, "PyTDMDescriptor");
  // py_tdm_descriptor_type holds its own (new) reference; the module reference is no longer needed.
  Py_DECREF(driver_mod);
  if (py_tdm_descriptor_type == NULL) {{
    Py_DECREF(m);
    return NULL;
  }}
  PyModule_AddFunctions(m, ModuleMethods);
  return m;
}}
"""
    return src
def make_tensordesc_arg(arg, kernel_metadata, tensordesc_metadata):
    """Expand one tensor-descriptor argument into the flat kernel arguments it stands for.

    Without ``tensordesc_metadata``, the descriptor is decomposed into
    ``[base, *shape, *strides, padding == "nan", *shape, *strides]``.

    With metadata, a TDMDescriptor is built by the active HIP driver's
    ``create_tdm_descriptor``, and ``[desc, *shape, *strides]`` is returned.

    Args:
        arg: Host-side tensor descriptor exposing ``base``, ``shape``, ``strides`` and
            ``padding``.
        kernel_metadata: Sequence whose first element is the kernel's ``num_warps``.
        tensordesc_metadata: Dict with ``elem_bits``, ``block_size`` and optionally
            ``interval_padding_pairs`` (at most one ``(interval, amount)`` pair), or None.
    """
    if tensordesc_metadata is None:
        # The frontend already decomposes host-side descriptors into (desc, shape, strides).
        # Those values cannot be reused when expanding the descriptor, so this backend
        # supplies its own decomposition, and shape and strides end up passed twice.
        return [arg.base, *arg.shape, *arg.strides, arg.padding == "nan", *arg.shape, *arg.strides]

    shape, strides = arg.shape, arg.strides
    base = arg.base.data_ptr()

    assert "elem_bits" in tensordesc_metadata and "block_size" in tensordesc_metadata
    elem_bits = tensordesc_metadata["elem_bits"]
    block_size = tensordesc_metadata["block_size"]

    # Default: no interval padding.
    pad_interval = pad_amount = 0
    padding_pairs = tensordesc_metadata.get("interval_padding_pairs", [])
    if padding_pairs:
        assert len(padding_pairs) == 1 and len(padding_pairs[0]) == 2
        pad_interval, pad_amount = padding_pairs[0]

    num_warps = kernel_metadata[0]
    active_driver = triton.runtime.driver.active
    assert isinstance(active_driver, HIPDriver)
    descriptor = active_driver.utils.create_tdm_descriptor(elem_bits, block_size, num_warps, pad_interval,
                                                           pad_amount, shape, strides, base)
    return [descriptor, *shape, *strides]
def wrap_handle_tensordesc(launcher, signature, tensordesc_metadata):
    """Wrap a kernel launcher so tensor-descriptor arguments are expanded before launching.

    Args:
        launcher (callable): The original kernel launcher function.
        signature (Dict[int, str]): The kernel signature mapping argument indices to types.
        tensordesc_metadata (List[Dict] or None): Tensor-descriptor metadata in the order the
            descriptor arguments appear. None (or empty) means every descriptor is decomposed.

    Returns:
        callable: ``launcher`` itself when the signature has no tensor descriptors. Otherwise,
        a wrapper that replaces each descriptor argument with the list from
        ``make_tensordesc_arg``.
    """
    desc_positions = {
        pos
        for pos, sig in enumerate(signature.values())
        if isinstance(sig, str) and sig.startswith("tensordesc")
    }
    if not desc_positions:
        return launcher

    assert not tensordesc_metadata or len(tensordesc_metadata) == len(desc_positions)
    per_desc_meta = tensordesc_metadata or [None] * len(desc_positions)

    def inner(*args):
        # The leading arguments are the fixed launch arguments described by _BASE_ARGS_FORMAT.
        num_fixed = len(_BASE_ARGS_FORMAT)
        meta_args = args[:num_fixed]
        kernel_args = args[num_fixed:]
        expanded = []
        desc_no = 0
        for pos, value in enumerate(kernel_args):
            if pos not in desc_positions:
                expanded.append(value)
                continue
            # meta_args[7] is kernel_metadata.
            expanded.extend(make_tensordesc_arg(value, meta_args[7], per_desc_meta[desc_no]))
            desc_no += 1
        return launcher(*meta_args, *expanded)

    return inner
class HIPLauncher(object):
    """Launches one compiled kernel through a generated ``__triton_launcher`` C extension."""

    def __init__(self, src, metadata):
        def to_index(key):
            # Constants keyed by parameter name are re-keyed by a 1-tuple of that parameter's index.
            if isinstance(key, str):
                return (src.fn.arg_names.index(key), )
            return key

        raw_constants = src.constants if hasattr(src, "constants") else dict()
        constants = {to_index(key): value for key, value in raw_constants.items()}
        signature = dict(src.signature.items())
        tensordesc_meta = getattr(metadata, "tensordesc_meta", None)

        launcher_src = make_launcher(constants, signature, metadata.warp_size, tensordesc_meta)
        module = compile_module_from_src(src=launcher_src, name="__triton_launcher", include_dirs=include_dirs)
        self.launch = wrap_handle_tensordesc(module.launch, signature, tensordesc_meta)

        self.launch_cooperative_grid = metadata.launch_cooperative_grid
        self.profile_scratch_size = metadata.profile_scratch_size
        self.profile_scratch_align = metadata.profile_scratch_align

    def __call__(self, gridX, gridY, gridZ, stream, function, *args):
        """Launch on grid (gridX, gridY, gridZ); `args` are forwarded after the profile scratch."""
        size = self.profile_scratch_size
        align = self.profile_scratch_align
        allocator = _allocation._profile_allocator
        # One profile-scratch slice of `size` bytes per program; none when size is 0.
        profile_scratch = None
        if size > 0:
            profile_scratch = allocator.get()(gridX * gridY * gridZ * size, align, stream)
        self.launch(self.launch_cooperative_grid, gridX, gridY, gridZ, stream, function, profile_scratch, *args)
class HIPDriver(GPUDriver):
    """Triton driver for AMD GPUs through HIP; active when PyTorch reports a HIP build."""

    def __init__(self):
        super().__init__()
        # Helpers from the compiled hip_utils extension (device properties, binaries, TDM).
        self.utils = HIPUtils()
        # Class instantiated for each compiled kernel to launch it.
        self.launcher_cls = HIPLauncher

    def get_device_interface(self):
        """Return ``torch.cuda``, the PyTorch device module used for HIP devices."""
        import torch
        return torch.cuda

    @staticmethod
    def is_active():
        """Return True if PyTorch is importable, sees a GPU, and was built with HIP."""
        try:
            import torch
            has_gpu = torch.cuda.is_available()
            return has_gpu and (torch.version.hip is not None)
        except ImportError:
            return False

    def map_python_to_cpp_type(self, ty: str) -> str:
        """Translate a Triton type string to its C type via ``ty_to_cpp``."""
        return ty_to_cpp(ty)

    def get_current_target(self):
        """Describe the current device as a ``GPUTarget("hip", arch, warp_size)``."""
        props = self.utils.get_device_properties(self.get_current_device())
        # A user-supplied arch override takes precedence over the probed value.
        arch = knobs.runtime.override_arch or props['arch']
        warp_size = props['warpSize']
        # Keep only the part of the arch string before the first ':'.
        return GPUTarget("hip", arch.partition(':')[0], warp_size)

    def get_active_torch_device(self):
        import torch
        # When using HIP devices, the device string in PyTorch is "cuda".
        return torch.device("cuda", self.get_current_device())

    def get_benchmarker(self):
        from triton.testing import do_bench
        return do_bench

    def get_empty_cache_for_benchmark(self):
        import torch
        # 256 MiB of 4-byte ints, the same size as the Nvidia backend uses.
        num_elements = 256 * 1024 * 1024 // 4
        return torch.empty(num_elements, dtype=torch.int, device='cuda')

    def clear_cache(self, cache):
        cache.zero_()
|